]> git.saurik.com Git - redis.git/blobdiff - src/scripting.c
Sentinel: Support for AUTH.
[redis.git] / src / scripting.c
index 0f876961a221cf4c200d1c4c76edba0615585a9e..35b654f7739c490b806cbe5896d30766cbb98ad9 100644 (file)
@@ -167,6 +167,13 @@ int luaRedisGenericCommand(lua_State *lua, int raise_error) {
     redisClient *c = server.lua_client;
     sds reply;
 
+    /* Require at least one argument */
+    if (argc == 0) {
+        luaPushError(lua,
+            "Please specify at least one argument for redis.call()");
+        return 1;
+    }
+
     /* Build the arguments vector */
     argv = zmalloc(sizeof(robj*)*argc);
     for (j = 0; j < argc; j++) {
@@ -275,11 +282,10 @@ int luaRedisGenericCommand(lua_State *lua, int raise_error) {
      * reply as expected. */
     if ((cmd->flags & REDIS_CMD_SORT_FOR_SCRIPT) &&
         (reply[0] == '*' && reply[1] != '-')) {
-        /* Skip this step if command is SORT but output was already sorted */
-        if (cmd->proc != sortCommand || server.sort_dontsort)
             luaSortArray(lua);
     }
     sdsfree(reply);
+    c->reply_bytes = 0;
 
 cleanup:
     /* Clean up. Command code may have changed argv/argc so we use the
@@ -412,13 +418,59 @@ void luaLoadLibraries(lua_State *lua) {
 #endif
 }
 
+/* Remove a functions that we don't want to expose to the Redis scripting
+ * environment. */
+void luaRemoveUnsupportedFunctions(lua_State *lua) {
+    lua_pushnil(lua);
+    lua_setglobal(lua,"loadfile");
+}
+
+/* This function installs metamethods in the global table _G that prevent
+ * the creation of globals accidentally.
+ *
+ * It should be the last to be called in the scripting engine initialization
+ * sequence, because it may interact with creation of globals. */
+void scriptingEnableGlobalsProtection(lua_State *lua) {
+    char *s[32];
+    sds code = sdsempty();
+    int j = 0;
+
+    /* strict.lua from: http://metalua.luaforge.net/src/lib/strict.lua.html.
+     * Modified to be adapted to Redis. */
+    s[j++]="local mt = {}\n";
+    s[j++]="setmetatable(_G, mt)\n";
+    s[j++]="mt.__newindex = function (t, n, v)\n";
+    s[j++]="  if debug.getinfo(2) then\n";
+    s[j++]="    local w = debug.getinfo(2, \"S\").what\n";
+    s[j++]="    if w ~= \"main\" and w ~= \"C\" then\n";
+    s[j++]="      error(\"Script attempted to create global variable '\"..tostring(n)..\"'\", 2)\n";
+    s[j++]="    end\n";
+    s[j++]="  end\n";
+    s[j++]="  rawset(t, n, v)\n";
+    s[j++]="end\n";
+    s[j++]="mt.__index = function (t, n)\n";
+    s[j++]="  if debug.getinfo(2) and debug.getinfo(2, \"S\").what ~= \"C\" then\n";
+    s[j++]="    error(\"Script attempted to access unexisting global variable '\"..tostring(n)..\"'\", 2)\n";
+    s[j++]="  end\n";
+    s[j++]="  return rawget(t, n)\n";
+    s[j++]="end\n";
+    s[j++]=NULL;
+
+    for (j = 0; s[j] != NULL; j++) code = sdscatlen(code,s[j],strlen(s[j]));
+    luaL_loadbuffer(lua,code,sdslen(code),"@enable_strict_lua");
+    lua_pcall(lua,0,0,0);
+    sdsfree(code);
+}
+
 /* Initialize the scripting environment.
  * It is possible to call this function to reset the scripting environment
  * assuming that we call scriptingRelease() before.
  * See scriptingReset() for more information. */
 void scriptingInit(void) {
     lua_State *lua = lua_open();
+
     luaLoadLibraries(lua);
+    luaRemoveUnsupportedFunctions(lua);
 
     /* Initialize a dictionary we use to map SHAs to scripts.
      * This is useful for replication, as we need to replicate EVALSHA
@@ -488,7 +540,7 @@ void scriptingInit(void) {
                                 "  if b == false then b = '' end\n"
                                 "  return a<b\n"
                                 "end\n";
-        luaL_loadbuffer(lua,compare_func,strlen(compare_func),"cmp_func_def");
+        luaL_loadbuffer(lua,compare_func,strlen(compare_func),"@cmp_func_def");
         lua_pcall(lua,0,0,0);
     }
 
@@ -501,6 +553,11 @@ void scriptingInit(void) {
         server.lua_client->flags |= REDIS_LUA_CLIENT;
     }
 
+    /* Lua beginners ofter don't use "local", this is likely to introduce
+     * subtle bugs in their code. To prevent problems we protect accesses
+     * to global variables. */
+    scriptingEnableGlobalsProtection(lua);
+
     server.lua = lua;
 }
 
@@ -634,7 +691,7 @@ int luaCreateFunction(redisClient *c, lua_State *lua, char *funcname, robj *body
     funcdef = sdscatlen(funcdef,body->ptr,sdslen(body->ptr));
     funcdef = sdscatlen(funcdef," end",4);
 
-    if (luaL_loadbuffer(lua,funcdef,sdslen(funcdef),"func definition")) {
+    if (luaL_loadbuffer(lua,funcdef,sdslen(funcdef),"@user_script")) {
         addReplyErrorFormat(c,"Error compiling script (new function): %s\n",
             lua_tostring(lua,-1));
         lua_pop(lua,1);
@@ -665,6 +722,7 @@ void evalGenericCommand(redisClient *c, int evalsha) {
     lua_State *lua = server.lua;
     char funcname[43];
     long long numkeys;
+    int delhook = 0;
 
     /* We want the same PRNG sequence at every call so that our PRNG is
      * not affected by external state. */
@@ -735,19 +793,19 @@ void evalGenericCommand(redisClient *c, int evalsha) {
      * is running for too much time.
      * We set the hook only if the time limit is enabled as the hook will
      * make the Lua script execution slower. */
+    server.lua_caller = c;
+    server.lua_time_start = ustime()/1000;
+    server.lua_kill = 0;
     if (server.lua_time_limit > 0 && server.masterhost == NULL) {
         lua_sethook(lua,luaMaskCountHook,LUA_MASKCOUNT,100000);
-    } else {
-        lua_sethook(lua,luaMaskCountHook,0,0);
+        delhook = 1;
     }
 
     /* At this point whatever this script was never seen before or if it was
      * already defined, we can call it. We have zero arguments and expect
      * a single return value. */
-    server.lua_caller = c;
-    server.lua_time_start = ustime()/1000;
-    server.lua_kill = 0;
     if (lua_pcall(lua,0,1,0)) {
+        if (delhook) lua_sethook(lua,luaMaskCountHook,0,0); /* Disable hook */
         if (server.lua_timedout) {
             server.lua_timedout = 0;
             /* Restore the readable handler that was unregistered when the
@@ -763,6 +821,7 @@ void evalGenericCommand(redisClient *c, int evalsha) {
         lua_gc(lua,LUA_GCCOLLECT,0);
         return;
     }
+    if (delhook) lua_sethook(lua,luaMaskCountHook,0,0); /* Disable hook */
     server.lua_timedout = 0;
     server.lua_caller = NULL;
     selectDb(c,server.lua_client->db->id); /* set DB ID from Lua client */