]> git.saurik.com Git - redis.git/blobdiff - src/scripting.c
Scripting: add helper functions redis.error_reply() and redis.status_reply().
[redis.git] / src / scripting.c
index 76b25ad2334e5e455f238e6426ea59d19ea45574..01aa2006f3ddd02fad86bd09ac9c7b5ac25c768a 100644 (file)
@@ -167,6 +167,13 @@ int luaRedisGenericCommand(lua_State *lua, int raise_error) {
     redisClient *c = server.lua_client;
     sds reply;
 
     redisClient *c = server.lua_client;
     sds reply;
 
+    /* Require at least one argument */
+    if (argc == 0) {
+        luaPushError(lua,
+            "Please specify at least one argument for redis.call()");
+        return 1;
+    }
+
     /* Build the arguments vector */
     argv = zmalloc(sizeof(robj*)*argc);
     for (j = 0; j < argc; j++) {
     /* Build the arguments vector */
     argv = zmalloc(sizeof(robj*)*argc);
     for (j = 0; j < argc; j++) {
@@ -275,11 +282,10 @@ int luaRedisGenericCommand(lua_State *lua, int raise_error) {
      * reply as expected. */
     if ((cmd->flags & REDIS_CMD_SORT_FOR_SCRIPT) &&
         (reply[0] == '*' && reply[1] != '-')) {
      * reply as expected. */
     if ((cmd->flags & REDIS_CMD_SORT_FOR_SCRIPT) &&
         (reply[0] == '*' && reply[1] != '-')) {
-        /* Skip this step if command is SORT but output was already sorted */
-        if (cmd->proc != sortCommand || server.sort_dontsort)
             luaSortArray(lua);
     }
     sdsfree(reply);
             luaSortArray(lua);
     }
     sdsfree(reply);
+    c->reply_bytes = 0;
 
 cleanup:
     /* Clean up. Command code may have changed argv/argc so we use the
 
 cleanup:
     /* Clean up. Command code may have changed argv/argc so we use the
@@ -326,6 +332,34 @@ int luaRedisSha1hexCommand(lua_State *lua) {
     return 1;
 }
 
     return 1;
 }
 
+/* Returns a table with a single field 'field' set to the string value
+ * passed as argument. This helper function is handy when returning
+ * a Redis Protocol error or status reply from Lua:
+ *
+ * return redis.error_reply("ERR Some Error")
+ * return redis.status_reply("ERR Some Error")
+ */
+int luaRedisReturnSingleFieldTable(lua_State *lua, char *field) {
+    if (lua_gettop(lua) != 1 || lua_type(lua,-1) != LUA_TSTRING) {
+        luaPushError(lua, "wrong number or type of arguments");
+        return 1;
+    }
+
+    lua_newtable(lua);
+    lua_pushstring(lua, field);
+    lua_pushvalue(lua, -3);
+    lua_settable(lua, -3);
+    return 1;
+}
+
+int luaRedisErrorReplyCommand(lua_State *lua) {
+    return luaRedisReturnSingleFieldTable(lua,"err");
+}
+
+int luaRedisStatusReplyCommand(lua_State *lua) {
+    return luaRedisReturnSingleFieldTable(lua,"ok");
+}
+
 int luaLogCommand(lua_State *lua) {
     int j, argc = lua_gettop(lua);
     int level;
 int luaLogCommand(lua_State *lua) {
     int j, argc = lua_gettop(lua);
     int level;
@@ -412,6 +446,13 @@ void luaLoadLibraries(lua_State *lua) {
 #endif
 }
 
 #endif
 }
 
+/* Remove a functions that we don't want to expose to the Redis scripting
+ * environment. */
+void luaRemoveUnsupportedFunctions(lua_State *lua) {
+    lua_pushnil(lua);
+    lua_setglobal(lua,"loadfile");
+}
+
 /* This function installs metamethods in the global table _G that prevent
  * the creation of globals accidentally.
  *
 /* This function installs metamethods in the global table _G that prevent
  * the creation of globals accidentally.
  *
@@ -437,14 +478,14 @@ void scriptingEnableGlobalsProtection(lua_State *lua) {
     s[j++]="end\n";
     s[j++]="mt.__index = function (t, n)\n";
     s[j++]="  if debug.getinfo(2) and debug.getinfo(2, \"S\").what ~= \"C\" then\n";
     s[j++]="end\n";
     s[j++]="mt.__index = function (t, n)\n";
     s[j++]="  if debug.getinfo(2) and debug.getinfo(2, \"S\").what ~= \"C\" then\n";
-    s[j++]="    error(\"Script attempted to access unexisting global variable '\"..n..\"'\", 2)\n";
+    s[j++]="    error(\"Script attempted to access unexisting global variable '\"..tostring(n)..\"'\", 2)\n";
     s[j++]="  end\n";
     s[j++]="  return rawget(t, n)\n";
     s[j++]="end\n";
     s[j++]=NULL;
 
     for (j = 0; s[j] != NULL; j++) code = sdscatlen(code,s[j],strlen(s[j]));
     s[j++]="  end\n";
     s[j++]="  return rawget(t, n)\n";
     s[j++]="end\n";
     s[j++]=NULL;
 
     for (j = 0; s[j] != NULL; j++) code = sdscatlen(code,s[j],strlen(s[j]));
-    luaL_loadbuffer(lua,code,sdslen(code),"enable_strict_lua");
+    luaL_loadbuffer(lua,code,sdslen(code),"@enable_strict_lua");
     lua_pcall(lua,0,0,0);
     sdsfree(code);
 }
     lua_pcall(lua,0,0,0);
     sdsfree(code);
 }
@@ -455,7 +496,9 @@ void scriptingEnableGlobalsProtection(lua_State *lua) {
  * See scriptingReset() for more information. */
 void scriptingInit(void) {
     lua_State *lua = lua_open();
  * See scriptingReset() for more information. */
 void scriptingInit(void) {
     lua_State *lua = lua_open();
+
     luaLoadLibraries(lua);
     luaLoadLibraries(lua);
+    luaRemoveUnsupportedFunctions(lua);
 
     /* Initialize a dictionary we use to map SHAs to scripts.
      * This is useful for replication, as we need to replicate EVALSHA
 
     /* Initialize a dictionary we use to map SHAs to scripts.
      * This is useful for replication, as we need to replicate EVALSHA
@@ -501,6 +544,22 @@ void scriptingInit(void) {
     lua_pushcfunction(lua, luaRedisSha1hexCommand);
     lua_settable(lua, -3);
 
     lua_pushcfunction(lua, luaRedisSha1hexCommand);
     lua_settable(lua, -3);
 
+    /* redis.NIL */
+    lua_pushstring(lua, "NIL");
+    lua_newtable(lua);
+    lua_pushstring(lua, "nilbulk");
+    lua_pushboolean(lua, 1);
+    lua_settable(lua, -3);
+    lua_settable(lua, -3);
+
+    /* redis.error_reply and redis.status_reply */
+    lua_pushstring(lua, "error_reply");
+    lua_pushcfunction(lua, luaRedisErrorReplyCommand);
+    lua_settable(lua, -3);
+    lua_pushstring(lua, "status_reply");
+    lua_pushcfunction(lua, luaRedisStatusReplyCommand);
+    lua_settable(lua, -3);
+
     /* Finally set the table as 'redis' global var. */
     lua_setglobal(lua,"redis");
 
     /* Finally set the table as 'redis' global var. */
     lua_setglobal(lua,"redis");
 
@@ -525,7 +584,7 @@ void scriptingInit(void) {
                                 "  if b == false then b = '' end\n"
                                 "  return a<b\n"
                                 "end\n";
                                 "  if b == false then b = '' end\n"
                                 "  return a<b\n"
                                 "end\n";
-        luaL_loadbuffer(lua,compare_func,strlen(compare_func),"cmp_func_def");
+        luaL_loadbuffer(lua,compare_func,strlen(compare_func),"@cmp_func_def");
         lua_pcall(lua,0,0,0);
     }
 
         lua_pcall(lua,0,0,0);
     }
 
@@ -595,9 +654,30 @@ void luaReplyToRedisReply(redisClient *c, lua_State *lua) {
         addReplyLongLong(c,(long long)lua_tonumber(lua,-1));
         break;
     case LUA_TTABLE:
         addReplyLongLong(c,(long long)lua_tonumber(lua,-1));
         break;
     case LUA_TTABLE:
-        /* We need to check if it is an array, an error, or a status reply.
-         * Error are returned as a single element table with 'err' field.
-         * Status replies are returned as single elment table with 'ok' field */
+        /* The table can be an array or it may be in a special format that
+         * Lua uses to return special Redis protocol data types.
+         *
+         * 1) Errors are retuned as a single element table with 'err' field.
+         * 2) Status reply are returned as a single element table with 'ok'
+         *    field.
+         * 3) A Redis nil bulk reply is returned as a single element table
+         *    with 'nilbulk' field set to true.
+         *
+         * All the rest is considered just an array and is translated into
+         * a Redis multi bulk reply. */
+
+        /* Nil bulk reply */
+        lua_pushstring(lua,"nilbulk");
+        lua_gettable(lua,-2);
+        t = lua_type(lua,-1);
+        if (t == LUA_TBOOLEAN) {
+            addReply(c,shared.nullbulk);
+            lua_pop(lua,2);
+            return;
+        }
+        lua_pop(lua,1);
+
+        /* Error reply */
         lua_pushstring(lua,"err");
         lua_gettable(lua,-2);
         t = lua_type(lua,-1);
         lua_pushstring(lua,"err");
         lua_gettable(lua,-2);
         t = lua_type(lua,-1);
@@ -609,8 +689,9 @@ void luaReplyToRedisReply(redisClient *c, lua_State *lua) {
             lua_pop(lua,2);
             return;
         }
             lua_pop(lua,2);
             return;
         }
-
         lua_pop(lua,1);
         lua_pop(lua,1);
+
+        /* Status reply */
         lua_pushstring(lua,"ok");
         lua_gettable(lua,-2);
         t = lua_type(lua,-1);
         lua_pushstring(lua,"ok");
         lua_gettable(lua,-2);
         t = lua_type(lua,-1);
@@ -621,6 +702,7 @@ void luaReplyToRedisReply(redisClient *c, lua_State *lua) {
             sdsfree(ok);
             lua_pop(lua,1);
         } else {
             sdsfree(ok);
             lua_pop(lua,1);
         } else {
+            /* Multi bulk reply. */
             void *replylen = addDeferredMultiBulkLength(c);
             int j = 1, mbulklen = 0;
 
             void *replylen = addDeferredMultiBulkLength(c);
             int j = 1, mbulklen = 0;
 
@@ -676,7 +758,7 @@ int luaCreateFunction(redisClient *c, lua_State *lua, char *funcname, robj *body
     funcdef = sdscatlen(funcdef,body->ptr,sdslen(body->ptr));
     funcdef = sdscatlen(funcdef," end",4);
 
     funcdef = sdscatlen(funcdef,body->ptr,sdslen(body->ptr));
     funcdef = sdscatlen(funcdef," end",4);
 
-    if (luaL_loadbuffer(lua,funcdef,sdslen(funcdef),"func definition")) {
+    if (luaL_loadbuffer(lua,funcdef,sdslen(funcdef),"@user_script")) {
         addReplyErrorFormat(c,"Error compiling script (new function): %s\n",
             lua_tostring(lua,-1));
         lua_pop(lua,1);
         addReplyErrorFormat(c,"Error compiling script (new function): %s\n",
             lua_tostring(lua,-1));
         lua_pop(lua,1);
@@ -707,6 +789,7 @@ void evalGenericCommand(redisClient *c, int evalsha) {
     lua_State *lua = server.lua;
     char funcname[43];
     long long numkeys;
     lua_State *lua = server.lua;
     char funcname[43];
     long long numkeys;
+    int delhook = 0;
 
     /* We want the same PRNG sequence at every call so that our PRNG is
      * not affected by external state. */
 
     /* We want the same PRNG sequence at every call so that our PRNG is
      * not affected by external state. */
@@ -777,19 +860,19 @@ void evalGenericCommand(redisClient *c, int evalsha) {
      * is running for too much time.
      * We set the hook only if the time limit is enabled as the hook will
      * make the Lua script execution slower. */
      * is running for too much time.
      * We set the hook only if the time limit is enabled as the hook will
      * make the Lua script execution slower. */
+    server.lua_caller = c;
+    server.lua_time_start = ustime()/1000;
+    server.lua_kill = 0;
     if (server.lua_time_limit > 0 && server.masterhost == NULL) {
         lua_sethook(lua,luaMaskCountHook,LUA_MASKCOUNT,100000);
     if (server.lua_time_limit > 0 && server.masterhost == NULL) {
         lua_sethook(lua,luaMaskCountHook,LUA_MASKCOUNT,100000);
-    } else {
-        lua_sethook(lua,luaMaskCountHook,0,0);
+        delhook = 1;
     }
 
     /* At this point whatever this script was never seen before or if it was
      * already defined, we can call it. We have zero arguments and expect
      * a single return value. */
     }
 
     /* At this point whatever this script was never seen before or if it was
      * already defined, we can call it. We have zero arguments and expect
      * a single return value. */
-    server.lua_caller = c;
-    server.lua_time_start = ustime()/1000;
-    server.lua_kill = 0;
     if (lua_pcall(lua,0,1,0)) {
     if (lua_pcall(lua,0,1,0)) {
+        if (delhook) lua_sethook(lua,luaMaskCountHook,0,0); /* Disable hook */
         if (server.lua_timedout) {
             server.lua_timedout = 0;
             /* Restore the readable handler that was unregistered when the
         if (server.lua_timedout) {
             server.lua_timedout = 0;
             /* Restore the readable handler that was unregistered when the
@@ -805,6 +888,7 @@ void evalGenericCommand(redisClient *c, int evalsha) {
         lua_gc(lua,LUA_GCCOLLECT,0);
         return;
     }
         lua_gc(lua,LUA_GCCOLLECT,0);
         return;
     }
+    if (delhook) lua_sethook(lua,luaMaskCountHook,0,0); /* Disable hook */
     server.lua_timedout = 0;
     server.lua_caller = NULL;
     selectDb(c,server.lua_client->db->id); /* set DB ID from Lua client */
     server.lua_timedout = 0;
     server.lua_caller = NULL;
     selectDb(c,server.lua_client->db->id); /* set DB ID from Lua client */