git.saurik.com Git - redis.git/commitdiff
Merged zsetops branch from Pietern
author antirez <antirez@gmail.com>
Tue, 9 Mar 2010 15:25:55 +0000 (16:25 +0100)
committer antirez <antirez@gmail.com>
Tue, 9 Mar 2010 15:25:55 +0000 (16:25 +0100)
redis-cli.c
redis.c
test-redis.tcl

index 807f676d11af9ff2b7b7afbfa41a4a2e031c2226..04ff947e3da2744b825b8d965418b6abca5a18ba 100644 (file)
@@ -101,6 +101,8 @@ static struct redisCommand cmdTable[] = {
     {"zincrby",4,REDIS_CMD_BULK},
     {"zrem",3,REDIS_CMD_BULK},
     {"zremrangebyscore",4,REDIS_CMD_INLINE},
+    {"zmerge",-3,REDIS_CMD_INLINE},
+    {"zmergeweighed",-4,REDIS_CMD_INLINE},
     {"zrange",-4,REDIS_CMD_INLINE},
     {"zrank",3,REDIS_CMD_BULK},
     {"zrangebyscore",-4,REDIS_CMD_INLINE},
diff --git a/redis.c b/redis.c
index 8658bc1979263919d8ab0b809d03db3c8f2226e8..d15efbbc7763595cd84404f622e18d89bf464b88 100644 (file)
--- a/redis.c
+++ b/redis.c
@@ -678,6 +678,8 @@ static void zrevrankCommand(redisClient *c);
 static void hsetCommand(redisClient *c);
 static void hgetCommand(redisClient *c);
 static void zremrangebyrankCommand(redisClient *c);
+static void zunionCommand(redisClient *c);
+static void zinterCommand(redisClient *c);
 
 /*================================= Globals ================================= */
 
@@ -726,6 +728,8 @@ static struct redisCommand cmdTable[] = {
     {"zrem",zremCommand,3,REDIS_CMD_BULK,1,1,1},
     {"zremrangebyscore",zremrangebyscoreCommand,4,REDIS_CMD_INLINE,1,1,1},
     {"zremrangebyrank",zremrangebyrankCommand,4,REDIS_CMD_INLINE,1,1,1},
+    {"zunion",zunionCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,0,0,0},
+    {"zinter",zinterCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,0,0,0},
     {"zrange",zrangeCommand,-4,REDIS_CMD_INLINE,1,1,1},
     {"zrangebyscore",zrangebyscoreCommand,-4,REDIS_CMD_INLINE,1,1,1},
     {"zcount",zcountCommand,4,REDIS_CMD_INLINE,1,1,1},
@@ -4837,6 +4841,7 @@ static void sinterstoreCommand(redisClient *c) {
 
 #define REDIS_OP_UNION 0
 #define REDIS_OP_DIFF 1
+#define REDIS_OP_INTER 2
 
 static void sunionDiffGenericCommand(redisClient *c, robj **setskeys, int setsnum, robj *dstkey, int op) {
     dict **dv = zmalloc(sizeof(dict*)*setsnum);
@@ -5451,6 +5456,171 @@ static void zremrangebyrankCommand(redisClient *c) {
     }
 }
 
+/* One input sorted set taking part in a ZUNION/ZINTER operation. */
+typedef struct {
+    /* dict of the source zset (entry values are read as double scores),
+     * or NULL when the source key does not exist */
+    dict *dict;
+    /* factor every score read from this source is multiplied by;
+     * set to 1.0 unless the command supplies a WEIGHTS list */
+    double weight;
+} zsetopsrc;
+
+/* qsort(3) comparator ordering zsetopsrc entries by ascending cardinality.
+ *
+ * A source whose dict is NULL (missing key) is treated as having zero
+ * elements, so it sorts before every non-empty source.
+ *
+ * Returns a negative value, zero, or a positive value when s1 is smaller
+ * than, equal to, or larger than s2 respectively. */
+static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) {
+    const zsetopsrc *d1 = s1, *d2 = s2;
+    unsigned long size1, size2;
+    size1 = d1->dict ? dictSize(d1->dict) : 0;
+    size2 = d2->dict ? dictSize(d2->dict) : 0;
+    /* Compare instead of subtracting: the unsigned long difference does not
+     * fit in an int and could yield the wrong sign (or a bogus "equal") when
+     * the two sizes differ by more than INT_MAX. */
+    return (size1 > size2) - (size1 < size2);
+}
+
+/* Shared implementation of ZUNION and ZINTER.
+ *
+ * Argument layout, as read from the client:
+ *   argv[2]              number of input keys (zsetnum)
+ *   argv[3 .. 3+zsetnum) input keys
+ *   optional tail        "weights" followed by zsetnum numeric weights
+ *
+ * c      - client issuing the command; the reply is queued on it.
+ * dstkey - key the resulting sorted set is stored under (any previous value
+ *          of that key in c->db is deleted first).
+ * op     - REDIS_OP_UNION or REDIS_OP_INTER; anything else trips redisAssert.
+ *
+ * The score of every element in the result is the sum, over all the input
+ * sets that contain the element, of the element score multiplied by the
+ * weight of that set (1.0 when no weights are given).
+ *
+ * Replies with an error when fewer than one input key is requested, when the
+ * arguments are malformed, or when an input key holds a non-zset value;
+ * otherwise replies with the number of elements in the stored result and
+ * bumps server.dirty. */
+static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
+    int i, j, zsetnum;
+    zsetopsrc *src;
+    robj *dstobj;
+    zset *dstzset;
+    dictIterator *di;
+    dictEntry *de;
+
+    /* expect zsetnum input keys to be given */
+    zsetnum = atoi(c->argv[2]->ptr);
+    if (zsetnum < 1) {
+        addReplySds(c,sdsnew("-ERR at least 1 input key is needed for ZUNION/ZINTER\r\n"));
+        return;
+    }
+
+    /* test if the expected number of keys exceeds the arguments actually
+     * given; written as a subtraction so that a huge zsetnum cannot overflow
+     * the signed addition 3+zsetnum */
+    if (zsetnum > c->argc-3) {
+        addReply(c,shared.syntaxerr);
+        return;
+    }
+
+    /* read keys to be used for input; src is released with zfree() on every
+     * path, so it has to come from zmalloc() and not from plain malloc() */
+    src = zmalloc(sizeof(zsetopsrc) * zsetnum);
+    for (i = 0, j = 3; i < zsetnum; i++, j++) {
+        robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]);
+        if (!zsetobj) {
+            /* a missing key behaves like an empty sorted set */
+            src[i].dict = NULL;
+        } else {
+            if (zsetobj->type != REDIS_ZSET) {
+                zfree(src);
+                addReply(c,shared.wrongtypeerr);
+                return;
+            }
+            src[i].dict = ((zset*)zsetobj->ptr)->dict;
+        }
+
+        /* default all weights to 1 */
+        src[i].weight = 1.0;
+    }
+
+    /* parse optional extra arguments */
+    if (j < c->argc) {
+        int remaining = c->argc-j;
+
+        while (remaining) {
+            if (!strcasecmp(c->argv[j]->ptr,"weights")) {
+                j++; remaining--;
+                if (remaining < zsetnum) {
+                    zfree(src);
+                    addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
+                    return;
+                }
+                /* NOTE: weights are not validated; strtod() yields 0.0 for
+                 * text that does not start with a number */
+                for (i = 0; i < zsetnum; i++, j++, remaining--) {
+                    src[i].weight = strtod(c->argv[j]->ptr, NULL);
+                }
+            } else {
+                zfree(src);
+                addReply(c,shared.syntaxerr);
+                return;
+            }
+        }
+    }
+
+    dstobj = createZsetObject();
+    dstzset = dstobj->ptr;
+
+    if (op == REDIS_OP_INTER) {
+        /* sort sets from the smallest to largest, this will improve our
+         * algorithm's performance */
+        qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
+
+        /* skip going over all entries if the smallest zset is NULL or empty */
+        if (src[0].dict && dictSize(src[0].dict) > 0) {
+            /* precondition: as src[0].dict is non-empty and the zsets are ordered
+             * from small to large, all src[i > 0].dict are non-empty too */
+            di = dictGetIterator(src[0].dict);
+            while((de = dictNext(di)) != NULL) {
+                double *score = zmalloc(sizeof(double));
+                *score = 0.0;
+
+                /* accumulate the weighted score, stopping at the first
+                 * source that does not contain the element */
+                for (j = 0; j < zsetnum; j++) {
+                    dictEntry *other = (j == 0) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+                    if (other) {
+                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+                    } else {
+                        break;
+                    }
+                }
+
+                /* skip entry when not present in every source dict */
+                if (j != zsetnum) {
+                    zfree(score);
+                } else {
+                    robj *o = dictGetEntryKey(de);
+                    dictAdd(dstzset->dict,o,score);
+                    incrRefCount(o); /* added to dictionary */
+                    zslInsert(dstzset->zsl,*score,o);
+                    incrRefCount(o); /* added to skiplist */
+                }
+            }
+            dictReleaseIterator(di);
+        }
+    } else if (op == REDIS_OP_UNION) {
+        for (i = 0; i < zsetnum; i++) {
+            if (!src[i].dict) continue;
+
+            di = dictGetIterator(src[i].dict);
+            while((de = dictNext(di)) != NULL) {
+                /* skip key when already processed */
+                if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
+
+                /* first time this element is seen: sum its weighted score
+                 * over every source that contains it */
+                double *score = zmalloc(sizeof(double));
+                *score = 0.0;
+                for (j = 0; j < zsetnum; j++) {
+                    if (!src[j].dict) continue;
+
+                    dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+                    if (other) {
+                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+                    }
+                }
+
+                robj *o = dictGetEntryKey(de);
+                dictAdd(dstzset->dict,o,score);
+                incrRefCount(o); /* added to dictionary */
+                zslInsert(dstzset->zsl,*score,o);
+                incrRefCount(o); /* added to skiplist */
+            }
+            dictReleaseIterator(di);
+        }
+    } else {
+        /* unknown operator */
+        redisAssert(op == REDIS_OP_INTER || op == REDIS_OP_UNION);
+    }
+
+    /* replace whatever was stored at dstkey with the freshly built zset */
+    deleteKey(c->db,dstkey);
+    dictAdd(c->db->dict,dstkey,dstobj);
+    incrRefCount(dstkey);
+
+    addReplyLong(c, dstzset->zsl->length);
+    server.dirty++;
+    zfree(src);
+}
+
+/* ZUNION command handler: argv[1] names the destination key; the remaining
+ * arguments are interpreted by the generic union/intersection routine. */
+static void zunionCommand(redisClient *c) {
+    robj *destination = c->argv[1];
+
+    zunionInterGenericCommand(c,destination,REDIS_OP_UNION);
+}
+
+/* ZINTER command handler: argv[1] names the destination key; the remaining
+ * arguments are interpreted by the generic union/intersection routine. */
+static void zinterCommand(redisClient *c) {
+    robj *destination = c->argv[1];
+
+    zunionInterGenericCommand(c,destination,REDIS_OP_INTER);
+}
+
 static void zrangeGenericCommand(redisClient *c, int reverse) {
     robj *o;
     int start = atoi(c->argv[2]->ptr);
index 9139f544754ae8abd17217b9881e63b4de921de9..00370a4c4b955f8bc442d135613dfad3661b7f4f 100644 (file)
@@ -1476,6 +1476,29 @@ proc main {server port} {
         list [$r zremrangebyrank zset 1 3] [$r zrange zset 0 -1]
     } {3 {a e}}
 
+    test {ZUNION basics} {
+        $r del zseta zsetb zsetc
+        $r zadd zseta 1 a
+        $r zadd zseta 2 b
+        $r zadd zseta 3 c
+        $r zadd zsetb 1 b
+        $r zadd zsetb 2 c
+        $r zadd zsetb 3 d
+        list [$r zunion zsetc 2 zseta zsetb] [$r zrange zsetc 0 -1 withscores]
+    } {4 {a 1 b 3 d 3 c 5}}
+
+    test {ZUNION with weights} {
+        list [$r zunion zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores]
+    } {4 {a 2 b 7 d 9 c 12}}
+
+    test {ZINTER basics} {
+        list [$r zinter zsetc 2 zseta zsetb] [$r zrange zsetc 0 -1 withscores]
+    } {2 {b 3 c 5}}
+
+    test {ZINTER with weights} {
+        list [$r zinter zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores]
+    } {2 {b 7 c 12}}
+
     test {SORT against sorted sets} {
         $r del zset
         $r zadd zset 1 a