]> git.saurik.com Git - redis.git/blobdiff - src/scripting.c
Now it is possible to return multi bulks of multi bulks from Lua, just returning...
[redis.git] / src / scripting.c
index 9b1fb12d465827cf9f22f347fd461bf0b3003a5a..21b3214db7d7743a7376b4d1fb3cf7b378672ad3 100644 (file)
@@ -67,8 +67,8 @@ char *redisProtocolToLuaType_Bulk(lua_State *lua, char *reply) {
     long long bulklen;
 
     string2ll(reply+1,p-reply-1,&bulklen);
-    if (bulklen == 0) {
-        lua_pushnil(lua);
+    if (bulklen == -1) {
+        lua_pushboolean(lua,0);
         return p+2;
     } else {
         lua_pushlstring(lua,p+2,bulklen);
@@ -79,7 +79,10 @@ char *redisProtocolToLuaType_Bulk(lua_State *lua, char *reply) {
 char *redisProtocolToLuaType_Status(lua_State *lua, char *reply) {
     char *p = strchr(reply+1,'\r');
 
+    lua_newtable(lua);
+    lua_pushstring(lua,"ok");
     lua_pushlstring(lua,reply+1,p-reply-1);
+    lua_settable(lua,-3);
     return p+2;
 }
 
@@ -101,7 +104,7 @@ char *redisProtocolToLuaType_MultiBulk(lua_State *lua, char *reply) {
     string2ll(reply+1,p-reply-1,&mbulklen);
     p += 2;
     if (mbulklen == -1) {
-        lua_pushnil(lua);
+        lua_pushboolean(lua,0);
         return p;
     }
     lua_newtable(lua);
@@ -113,6 +116,13 @@ char *redisProtocolToLuaType_MultiBulk(lua_State *lua, char *reply) {
     return p;
 }
 
+void luaPushError(lua_State *lua, char *error) {
+    lua_newtable(lua);
+    lua_pushstring(lua,"err");
+    lua_pushstring(lua, error);
+    lua_settable(lua,-3);
+}
+
 int luaRedisCommand(lua_State *lua) {
     int j, argc = lua_gettop(lua);
     struct redisCommand *cmd;
@@ -122,9 +132,26 @@ int luaRedisCommand(lua_State *lua) {
 
     /* Build the arguments vector */
     argv = zmalloc(sizeof(robj*)*argc);
-    for (j = 0; j < argc; j++)
+    for (j = 0; j < argc; j++) {
+        if (!lua_isstring(lua,j+1)) break;
         argv[j] = createStringObject((char*)lua_tostring(lua,j+1),
                                      lua_strlen(lua,j+1));
+    }
+    
+    /* Check if one of the arguments passed by the Lua script
+     * is not a string or an integer (lua_isstring() return true for
+     * integers as well). */
+    if (j != argc) {
+        j--;
+        while (j >= 0) {
+            decrRefCount(argv[j]);
+            j--;
+        }
+        zfree(argv);
+        luaPushError(lua,
+            "Lua redis() command arguments must be strings or integers");
+        return 1;
+    }
 
     /* Command lookup */
     cmd = lookupCommand(argv[0]->ptr);
@@ -133,14 +160,11 @@ int luaRedisCommand(lua_State *lua) {
     {
         for (j = 0; j < argc; j++) decrRefCount(argv[j]);
         zfree(argv);
-        lua_newtable(lua);
-        lua_pushstring(lua,"err");
         if (cmd)
-            lua_pushstring(lua,
+            luaPushError(lua,
                 "Wrong number of args calling Redis command From Lua script");
         else
-            lua_pushstring(lua,"Unknown Redis command called from Lua script");
-        lua_settable(lua,-3);
+            luaPushError(lua,"Unknown Redis command called from Lua script");
         return 1;
     }
 
@@ -160,7 +184,7 @@ int luaRedisCommand(lua_State *lua) {
     while(listLength(c->reply)) {
         robj *o = listNodeValue(listFirst(c->reply));
 
-        sdscatlen(reply,o->ptr,sdslen(o->ptr));
+        reply = sdscatlen(reply,o->ptr,sdslen(o->ptr));
         listDelNode(c->reply,listFirst(c->reply));
     }
     redisProtocolToLuaType(lua,reply);
@@ -175,6 +199,18 @@ int luaRedisCommand(lua_State *lua) {
     return 1;
 }
 
+void luaMaskCountHook(lua_State *lua, lua_Debug *ar) {
+    long long elapsed;
+    REDIS_NOTUSED(ar);
+
+    elapsed = (ustime()/1000) - server.lua_time_start;
+    if (elapsed >= server.lua_time_limit) {
+        redisLog(REDIS_NOTICE,"Lua script aborted for max execution time after %lld milliseconds of running time.",elapsed);
+        lua_pushstring(lua,"Script aborted for max execution time.");
+        lua_error(lua);
+    }
+}
+
 void scriptingInit(void) {
     lua_State *lua = lua_open();
     luaL_openlibs(lua);
@@ -212,33 +248,45 @@ void hashScript(char *digest, char *script, size_t len) {
 }
 
 void luaReplyToRedisReply(redisClient *c, lua_State *lua) {
-    int t = lua_type(lua,1);
+    int t = lua_type(lua,-1);
 
     switch(t) {
     case LUA_TSTRING:
-        addReplyBulkCBuffer(c,(char*)lua_tostring(lua,1),lua_strlen(lua,1));
+        addReplyBulkCBuffer(c,(char*)lua_tostring(lua,-1),lua_strlen(lua,-1));
         break;
     case LUA_TBOOLEAN:
-        addReply(c,lua_toboolean(lua,1) ? shared.cone : shared.czero);
+        addReply(c,lua_toboolean(lua,-1) ? shared.cone : shared.nullbulk);
         break;
     case LUA_TNUMBER:
-        addReplyLongLong(c,(long long)lua_tonumber(lua,1));
+        addReplyLongLong(c,(long long)lua_tonumber(lua,-1));
         break;
     case LUA_TTABLE:
-        /* We need to check if it is an array or an error.
-         * Error are returned as a single element table with 'err' field. */
+        /* We need to check if it is an array, an error, or a status reply.
+         * Error are returned as a single element table with 'err' field.
+         * Status replies are returned as single elment table with 'ok' field */
         lua_pushstring(lua,"err");
         lua_gettable(lua,-2);
         t = lua_type(lua,-1);
         if (t == LUA_TSTRING) {
             addReplySds(c,sdscatprintf(sdsempty(),
                     "-%s\r\n",(char*)lua_tostring(lua,-1)));
+            lua_pop(lua,2);
+            return;
+        }
+
+        lua_pop(lua,1);
+        lua_pushstring(lua,"ok");
+        lua_gettable(lua,-2);
+        t = lua_type(lua,-1);
+        if (t == LUA_TSTRING) {
+            addReplySds(c,sdscatprintf(sdsempty(),
+                    "+%s\r\n",(char*)lua_tostring(lua,-1)));
             lua_pop(lua,1);
         } else {
             void *replylen = addDeferredMultiBulkLength(c);
             int j = 1, mbulklen = 0;
 
-            lua_pop(lua,1); /* Discard the 'err' field value we popped */
+            lua_pop(lua,1); /* Discard the 'ok' field value we popped */
             while(1) {
                 lua_pushnumber(lua,j++);
                 lua_gettable(lua,-2);
@@ -246,17 +294,9 @@ void luaReplyToRedisReply(redisClient *c, lua_State *lua) {
                 if (t == LUA_TNIL) {
                     lua_pop(lua,1);
                     break;
-                } else if (t == LUA_TSTRING) {
-                    size_t len;
-                    char *s = (char*) lua_tolstring(lua,-1,&len);
-
-                    addReplyBulkCBuffer(c,s,len);
-                    mbulklen++;
-                } else if (t == LUA_TNUMBER) {
-                    addReplyLongLong(c,(long long)lua_tonumber(lua,-1));
-                    mbulklen++;
                 }
-                lua_pop(lua,1);
+                luaReplyToRedisReply(c, lua);
+                mbulklen++;
             }
             setDeferredMultiBulkLength(c,replylen,mbulklen);
         }
@@ -309,7 +349,7 @@ void evalCommand(redisClient *c) {
         funcdef = sdscatlen(funcdef," ()\n",4);
         funcdef = sdscatlen(funcdef,c->argv[1]->ptr,sdslen(c->argv[1]->ptr));
         funcdef = sdscatlen(funcdef,"\nend\n",5);
-        printf("Defining:\n%s\n",funcdef);
+        /* printf("Defining:\n%s\n",funcdef); */
 
         if (luaL_loadbuffer(lua,funcdef,sdslen(funcdef),"func definition")) {
             addReplyErrorFormat(c,"Error compiling script (new function): %s\n",
@@ -332,15 +372,33 @@ void evalCommand(redisClient *c) {
      * EVAL received. */
     luaSetGlobalArray(lua,"KEYS",c->argv+3,numkeys);
     luaSetGlobalArray(lua,"ARGV",c->argv+3+numkeys,c->argc-3-numkeys);
+
+    /* Select the right DB in the context of the Lua client */
+    selectDb(server.lua_client,c->db->id);
     
+    /* Set an hook in order to be able to stop the script execution if it
+     * is running for too much time.
+     * We set the hook only if the time limit is enabled as the hook will
+     * make the Lua script execution slower. */
+    if (server.lua_time_limit > 0) {
+        lua_sethook(lua,luaMaskCountHook,LUA_MASKCOUNT,100000);
+        server.lua_time_start = ustime()/1000;
+    } else {
+        lua_sethook(lua,luaMaskCountHook,0,0);
+    }
+
     /* At this point whatever this script was never seen before or if it was
      * already defined, we can call it. We have zero arguments and expect
      * a single return value. */
     if (lua_pcall(lua,0,1,0)) {
+        selectDb(c,server.lua_client->db->id); /* set DB ID from Lua client */
         addReplyErrorFormat(c,"Error running script (call to %s): %s\n",
             funcname, lua_tostring(lua,-1));
         lua_pop(lua,1);
+        lua_gc(lua,LUA_GCCOLLECT,0);
         return;
     }
+    selectDb(c,server.lua_client->db->id); /* set DB ID from Lua client */
     luaReplyToRedisReply(c,lua);
+    lua_gc(lua,LUA_GCSTEP,1);
 }