Replaced ZMERGE with ZUNION and ZINTER. Note: key preloading by the VM does not yet...
author: Pieter Noordhuis <pcnoordhuis@gmail.com>
Tue, 9 Mar 2010 11:38:50 +0000 (12:38 +0100)
committer: Pieter Noordhuis <pcnoordhuis@gmail.com>
Tue, 9 Mar 2010 11:38:50 +0000 (12:38 +0100)
redis.c
test-redis.tcl

diff --git a/redis.c b/redis.c
index e4db385d601e63e6d4370b14bbe261168bd6e36e..94ab6720b1f3db1671dfbe65c01641530f269a1d 100644
--- a/redis.c
+++ b/redis.c
@@ -675,8 +675,8 @@ static void substrCommand(redisClient *c);
 static void zrankCommand(redisClient *c);
 static void hsetCommand(redisClient *c);
 static void hgetCommand(redisClient *c);
 static void zrankCommand(redisClient *c);
 static void hsetCommand(redisClient *c);
 static void hgetCommand(redisClient *c);
-static void zmergeCommand(redisClient *c);
-static void zmergeweighedCommand(redisClient *c);
+static void zunionCommand(redisClient *c);
+static void zinterCommand(redisClient *c);
 
 /*================================= Globals ================================= */
 
 
 /*================================= Globals ================================= */
 
@@ -724,8 +724,8 @@ static struct redisCommand cmdTable[] = {
     {"zincrby",zincrbyCommand,4,REDIS_CMD_BULK|REDIS_CMD_DENYOOM,1,1,1},
     {"zrem",zremCommand,3,REDIS_CMD_BULK,1,1,1},
     {"zremrangebyscore",zremrangebyscoreCommand,4,REDIS_CMD_INLINE,1,1,1},
     {"zincrby",zincrbyCommand,4,REDIS_CMD_BULK|REDIS_CMD_DENYOOM,1,1,1},
     {"zrem",zremCommand,3,REDIS_CMD_BULK,1,1,1},
     {"zremrangebyscore",zremrangebyscoreCommand,4,REDIS_CMD_INLINE,1,1,1},
-    {"zmerge",zmergeCommand,-3,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,2,-1,1},
-    {"zmergeweighed",zmergeweighedCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,2,-2,2},
+    {"zunion",zunionCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,0,0,0},
+    {"zinter",zinterCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,0,0,0},
     {"zrange",zrangeCommand,-4,REDIS_CMD_INLINE,1,1,1},
     {"zrangebyscore",zrangebyscoreCommand,-4,REDIS_CMD_INLINE,1,1,1},
     {"zcount",zcountCommand,4,REDIS_CMD_INLINE,1,1,1},
     {"zrange",zrangeCommand,-4,REDIS_CMD_INLINE,1,1,1},
     {"zrangebyscore",zrangebyscoreCommand,-4,REDIS_CMD_INLINE,1,1,1},
     {"zcount",zcountCommand,4,REDIS_CMD_INLINE,1,1,1},
@@ -4771,6 +4771,7 @@ static void sinterstoreCommand(redisClient *c) {
 
 #define REDIS_OP_UNION 0
 #define REDIS_OP_DIFF 1
 
 #define REDIS_OP_UNION 0
 #define REDIS_OP_DIFF 1
+#define REDIS_OP_INTER 2
 
 static void sunionDiffGenericCommand(redisClient *c, robj **setskeys, int setsnum, robj *dstkey, int op) {
     dict **dv = zmalloc(sizeof(dict*)*setsnum);
 
 static void sunionDiffGenericCommand(redisClient *c, robj **setskeys, int setsnum, robj *dstkey, int op) {
     dict **dv = zmalloc(sizeof(dict*)*setsnum);
@@ -5329,103 +5330,166 @@ static void zremrangebyscoreCommand(redisClient *c) {
     }
 }
 
     }
 }
 
-/* This command merges 2 or more zsets to a destination. When an element
- * does not exist in a certain set, score 0 is assumed. The score for an
- * element across sets is summed. */
-static void zmergeGenericCommand(redisClient *c, int readweights) {
-    int i, j, zsetnum;
-    dict **srcdict;
+static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
+    int i, j, k, zsetnum;
+    dict **srcdicts;
     double *weights;
     double *weights;
-    robj *dstkey = c->argv[1], *dstobj;
-    zset *dst;
+    robj *dstobj;
+    zset *dstzset;
     dictIterator *di;
     dictEntry *de;
 
     dictIterator *di;
     dictEntry *de;
 
-    zsetnum = c->argc-2;
-    if (readweights) {
-        /* force number of arguments to be even */
-        if (zsetnum % 2 > 0) {
-            addReplySds(c,sdsnew("-ERR wrong number of arguments for ZMERGEWEIGHED\r\n"));
-            return;
-        }
-        zsetnum /= 2;
+    /* expect zsetnum input keys to be given */
+    zsetnum = atoi(c->argv[2]->ptr);
+    if (zsetnum < 1) {
+        addReplySds(c,sdsnew("-ERR at least 1 input key is needed for ZUNION/ZINTER\r\n"));
+        return;
     }
     }
-    if (!zsetnum) {
+
+    /* test if the expected number of keys would overflow */
+    if (3+zsetnum > c->argc) {
         addReply(c,shared.syntaxerr);
         return;
     }
 
         addReply(c,shared.syntaxerr);
         return;
     }
 
-    srcdict = zmalloc(sizeof(dict*) * zsetnum);
+    /* read keys to be used for input */
+    srcdicts = zmalloc(sizeof(dict*) * zsetnum);
     weights = zmalloc(sizeof(double) * zsetnum);
     weights = zmalloc(sizeof(double) * zsetnum);
-    for (i = 0; i < zsetnum; i++) {
-        if (readweights) {
-            j = 2 + 2*i;
-            weights[i] = strtod(c->argv[j+1]->ptr, NULL);
-        } else {
-            j = 2 + i;
-            weights[i] = 1.0;
-        }
-
+    for (i = 0, j = 3; i < zsetnum; i++, j++) {
         robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]);
         if (!zsetobj) {
         robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]);
         if (!zsetobj) {
-            srcdict[i] = NULL;
+            srcdicts[i] = NULL;
         } else {
             if (zsetobj->type != REDIS_ZSET) {
         } else {
             if (zsetobj->type != REDIS_ZSET) {
-                zfree(srcdict);
+                zfree(srcdicts);
                 zfree(weights);
                 addReply(c,shared.wrongtypeerr);
                 return;
             }
                 zfree(weights);
                 addReply(c,shared.wrongtypeerr);
                 return;
             }
-            srcdict[i] = ((zset*)zsetobj->ptr)->dict;
+            srcdicts[i] = ((zset*)zsetobj->ptr)->dict;
         }
         }
+
+        /* default all weights to 1 */
+        weights[i] = 1.0;
     }
 
     }
 
-    dstobj = createZsetObject();
-    dst = dstobj->ptr;
-    for (i = 0; i < zsetnum; i++) {
-        if (!srcdict[i]) continue;
+    /* parse optional extra arguments */
+    if (j < c->argc) {
+        int remaining = c->argc-j;
 
 
-        di = dictGetIterator(srcdict[i]);
-        while((de = dictNext(di)) != NULL) {
-            /* skip key when already processed */
-            if (dictFind(dst->dict,dictGetEntryKey(de)) != NULL) continue;
+        while (remaining) {
+            if (!strcasecmp(c->argv[j]->ptr,"weights")) {
+                j++; remaining--;
+                if (remaining < zsetnum) {
+                    zfree(srcdicts);
+                    zfree(weights);
+                    addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
+                    return;
+                }
+                for (i = 0; i < zsetnum; i++, j++, remaining--) {
+                    weights[i] = strtod(c->argv[j]->ptr, NULL);
+                }
+            } else {
+                zfree(srcdicts);
+                zfree(weights);
+                addReply(c,shared.syntaxerr);
+                return;
+            }
+        }
+    }
 
 
-            double *score = zmalloc(sizeof(double));
-            *score = 0.0;
-            for (j = 0; j < zsetnum; j++) {
-                if (!srcdict[j]) continue;
+    dstobj = createZsetObject();
+    dstzset = dstobj->ptr;
+
+    if (op == REDIS_OP_INTER) {
+        /* store index of smallest zset in variable j */
+        for (i = 0, j = 0; i < zsetnum; i++) {
+            if (!srcdicts[i] || dictSize(srcdicts[i]) == 0) {
+                break;
+            }
+            if (dictSize(srcdicts[i]) < dictSize(srcdicts[j])) {
+                j = i;
+            }
+        }
+        /* skip going over all entries if at least one dict was NULL or empty */
+        if (i == zsetnum) {
+            /* precondition: all srcdicts are non-NULL and non-empty */
+            di = dictGetIterator(srcdicts[j]);
+            while((de = dictNext(di)) != NULL) {
+                double *score = zmalloc(sizeof(double));
+                *score = 0.0;
+
+                for (k = 0; k < zsetnum; k++) {
+                    dictEntry *other = (k == j) ? de : dictFind(srcdicts[k],dictGetEntryKey(de));
+                    if (other) {
+                        *score = *score + weights[k] * (*(double*)dictGetEntryVal(other));
+                    } else {
+                        break;
+                    }
+                }
 
 
-                dictEntry *other = dictFind(srcdict[j],dictGetEntryKey(de));
-                if (other) {
-                    *score = *score + weights[j] * (*(double*)dictGetEntryVal(other));
+                /* skip entry when not present in every source dict */
+                if (k != zsetnum) {
+                    zfree(score);
+                } else {
+                    robj *o = dictGetEntryKey(de);
+                    dictAdd(dstzset->dict,o,score);
+                    incrRefCount(o); /* added to dictionary */
+                    zslInsert(dstzset->zsl,*score,o);
+                    incrRefCount(o); /* added to skiplist */
                 }
             }
                 }
             }
+            dictReleaseIterator(di);
+        }
+    } else if (op == REDIS_OP_UNION) {
+        for (i = 0; i < zsetnum; i++) {
+            if (!srcdicts[i]) continue;
+
+            di = dictGetIterator(srcdicts[i]);
+            while((de = dictNext(di)) != NULL) {
+                /* skip key when already processed */
+                if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
+
+                double *score = zmalloc(sizeof(double));
+                *score = 0.0;
+                for (j = 0; j < zsetnum; j++) {
+                    if (!srcdicts[j]) continue;
+
+                    dictEntry *other = (i == j) ? de : dictFind(srcdicts[j],dictGetEntryKey(de));
+                    if (other) {
+                        *score = *score + weights[j] * (*(double*)dictGetEntryVal(other));
+                    }
+                }
 
 
-            robj *o = dictGetEntryKey(de);
-            dictAdd(dst->dict,o,score);
-            incrRefCount(o); /* added to dictionary */
-            zslInsert(dst->zsl,*score,o);
-            incrRefCount(o); /* added to skiplist */
+                robj *o = dictGetEntryKey(de);
+                dictAdd(dstzset->dict,o,score);
+                incrRefCount(o); /* added to dictionary */
+                zslInsert(dstzset->zsl,*score,o);
+                incrRefCount(o); /* added to skiplist */
+            }
+            dictReleaseIterator(di);
         }
         }
-        dictReleaseIterator(di);
+    } else {
+        /* unknown operator */
+        redisAssert(op == REDIS_OP_INTER || op == REDIS_OP_UNION);
     }
 
     deleteKey(c->db,dstkey);
     dictAdd(c->db->dict,dstkey,dstobj);
     incrRefCount(dstkey);
 
     }
 
     deleteKey(c->db,dstkey);
     dictAdd(c->db->dict,dstkey,dstobj);
     incrRefCount(dstkey);
 
-    addReplyLong(c, dst->zsl->length);
+    addReplyLong(c, dstzset->zsl->length);
     server.dirty++;
     server.dirty++;
-    zfree(srcdict);
+    zfree(srcdicts);
     zfree(weights);
 }
 
     zfree(weights);
 }
 
-static void zmergeCommand(redisClient *c) {
-    zmergeGenericCommand(c,0);
+static void zunionCommand(redisClient *c) {
+    zunionInterGenericCommand(c,c->argv[1], REDIS_OP_UNION);
 }
 
 }
 
-static void zmergeweighedCommand(redisClient *c) {
-    zmergeGenericCommand(c,1);
+static void zinterCommand(redisClient *c) {
+    zunionInterGenericCommand(c,c->argv[1], REDIS_OP_INTER);
 }
 
 static void zrangeGenericCommand(redisClient *c, int reverse) {
 }
 
 static void zrangeGenericCommand(redisClient *c, int reverse) {
diff --git a/test-redis.tcl b/test-redis.tcl
index 9d458ed7b497db99af4fa9a831f412618e535367..54b614fd3a3fa4b896cf45becdc7811db0a3bce4 100644
--- a/test-redis.tcl
+++ b/test-redis.tcl
@@ -1462,7 +1462,7 @@ proc main {server port} {
         list [$r zremrangebyscore zset -inf +inf] [$r zrange zset 0 -1]
     } {5 {}}
 
         list [$r zremrangebyscore zset -inf +inf] [$r zrange zset 0 -1]
     } {5 {}}
 
-    test {ZMERGE basics} {
+    test {ZUNION basics} {
         $r del zseta zsetb zsetc
         $r zadd zseta 1 a
         $r zadd zseta 2 b
         $r del zseta zsetb zsetc
         $r zadd zseta 1 a
         $r zadd zseta 2 b
@@ -1470,13 +1470,21 @@ proc main {server port} {
         $r zadd zsetb 1 b
         $r zadd zsetb 2 c
         $r zadd zsetb 3 d
         $r zadd zsetb 1 b
         $r zadd zsetb 2 c
         $r zadd zsetb 3 d
-        list [$r zmerge zsetc zseta zsetb] [$r zrange zsetc 0 -1 withscores]
+        list [$r zunion zsetc 2 zseta zsetb] [$r zrange zsetc 0 -1 withscores]
     } {4 {a 1 b 3 d 3 c 5}}
 
     } {4 {a 1 b 3 d 3 c 5}}
 
-    test {ZMERGEWEIGHED basics} {
-        list [$r zmergeweighed zsetc zseta 2 zsetb 3] [$r zrange zsetc 0 -1 withscores]
+    test {ZUNION with weights} {
+        list [$r zunion zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores]
     } {4 {a 2 b 7 d 9 c 12}}
 
     } {4 {a 2 b 7 d 9 c 12}}
 
+    test {ZINTER basics} {
+        list [$r zinter zsetc 2 zseta zsetb] [$r zrange zsetc 0 -1 withscores]
+    } {2 {b 3 c 5}}
+
+    test {ZINTER with weights} {
+        list [$r zinter zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores]
+    } {2 {b 7 c 12}}
+
     test {SORT against sorted sets} {
         $r del zset
         $r zadd zset 1 a
     test {SORT against sorted sets} {
         $r del zset
         $r zadd zset 1 a