]> git.saurik.com Git - redis.git/commitdiff
allow regular sets to be passed to zunionstore/zinterstore
authorJosiah Carlson <josiah@ad.ly>
Wed, 26 May 2010 00:41:35 +0000 (17:41 -0700)
committerPieter Noordhuis <pcnoordhuis@gmail.com>
Wed, 26 May 2010 14:07:04 +0000 (16:07 +0200)
redis.c
tests/unit/type/zset.tcl

diff --git a/redis.c b/redis.c
index 42dc00ace0f9bbd38812e84d67cfc0d885da84f7..f18eabe6afed27095f070f972cf614be8e1f836d 100644 (file)
--- a/redis.c
+++ b/redis.c
@@ -5922,6 +5922,7 @@ static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) {
 #define REDIS_AGGR_SUM 1
 #define REDIS_AGGR_MIN 2
 #define REDIS_AGGR_MAX 3
+#define zunionInterDictValue(_e) (dictGetEntryVal(_e) == NULL ? 1.0 : *(double*)dictGetEntryVal(_e))
 
 inline static void zunionInterAggregate(double *target, double val, int aggregate) {
     if (aggregate == REDIS_AGGR_SUM) {
@@ -5937,7 +5938,7 @@ inline static void zunionInterAggregate(double *target, double val, int aggregat
 }
 
 static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
-    int i, j, zsetnum;
+    int i, j, setnum;
     int aggregate = REDIS_AGGR_SUM;
     zsetopsrc *src;
     robj *dstobj;
@@ -5945,32 +5946,35 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
     dictIterator *di;
     dictEntry *de;
 
-    /* expect zsetnum input keys to be given */
-    zsetnum = atoi(c->argv[2]->ptr);
-    if (zsetnum < 1) {
+    /* expect setnum input keys to be given */
+    setnum = atoi(c->argv[2]->ptr);
+    if (setnum < 1) {
         addReplySds(c,sdsnew("-ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE\r\n"));
         return;
     }
 
     /* test if the expected number of keys would overflow */
-    if (3+zsetnum > c->argc) {
+    if (3+setnum > c->argc) {
         addReply(c,shared.syntaxerr);
         return;
     }
 
     /* read keys to be used for input */
-    src = zmalloc(sizeof(zsetopsrc) * zsetnum);
-    for (i = 0, j = 3; i < zsetnum; i++, j++) {
-        robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]);
-        if (!zsetobj) {
+    src = zmalloc(sizeof(zsetopsrc) * setnum);
+    for (i = 0, j = 3; i < setnum; i++, j++) {
+        robj *obj = lookupKeyWrite(c->db,c->argv[j]);
+        if (!obj) {
             src[i].dict = NULL;
         } else {
-            if (zsetobj->type != REDIS_ZSET) {
+            if (obj->type == REDIS_ZSET) {
+                src[i].dict = ((zset*)obj->ptr)->dict;
+            } else if (obj->type == REDIS_SET) {
+                src[i].dict = (obj->ptr);
+            } else {
                 zfree(src);
                 addReply(c,shared.wrongtypeerr);
                 return;
             }
-            src[i].dict = ((zset*)zsetobj->ptr)->dict;
         }
 
         /* default all weights to 1 */
@@ -5982,9 +5986,9 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
         int remaining = c->argc - j;
 
         while (remaining) {
-            if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) {
+            if (remaining >= (setnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) {
                 j++; remaining--;
-                for (i = 0; i < zsetnum; i++, j++, remaining--) {
+                for (i = 0; i < setnum; i++, j++, remaining--) {
                     if (getDoubleFromObjectOrReply(c, c->argv[j], &src[i].weight, NULL) != REDIS_OK)
                         return;
                 }
@@ -6012,7 +6016,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
 
     /* sort sets from the smallest to largest, this will improve our
      * algorithm's performance */
-    qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
+    qsort(src,setnum,sizeof(zsetopsrc),qsortCompareZsetopsrcByCardinality);
 
     dstobj = createZsetObject();
     dstzset = dstobj->ptr;
@@ -6025,12 +6029,12 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
             di = dictGetIterator(src[0].dict);
             while((de = dictNext(di)) != NULL) {
                 double *score = zmalloc(sizeof(double)), value;
-                *score = src[0].weight * (*(double*)dictGetEntryVal(de));
+                *score = src[0].weight * zunionInterDictValue(de);
 
-                for (j = 1; j < zsetnum; j++) {
+                for (j = 1; j < setnum; j++) {
                     dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        value = src[j].weight * (*(double*)dictGetEntryVal(other));
+                        value = src[j].weight * zunionInterDictValue(other);
                         zunionInterAggregate(score, value, aggregate);
                     } else {
                         break;
@@ -6038,7 +6042,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
                 }
 
                 /* skip entry when not present in every source dict */
-                if (j != zsetnum) {
+                if (j != setnum) {
                     zfree(score);
                 } else {
                     robj *o = dictGetEntryKey(de);
@@ -6051,7 +6055,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
             dictReleaseIterator(di);
         }
     } else if (op == REDIS_OP_UNION) {
-        for (i = 0; i < zsetnum; i++) {
+        for (i = 0; i < setnum; i++) {
             if (!src[i].dict) continue;
 
             di = dictGetIterator(src[i].dict);
@@ -6060,14 +6064,14 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
                 if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
 
                 double *score = zmalloc(sizeof(double)), value;
-                *score = src[i].weight * (*(double*)dictGetEntryVal(de));
+                *score = src[i].weight * zunionInterDictValue(de);
 
                 /* because the zsets are sorted by size, its only possible
                  * for sets at larger indices to hold this entry */
-                for (j = (i+1); j < zsetnum; j++) {
+                for (j = (i+1); j < setnum; j++) {
                     dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        value = src[j].weight * (*(double*)dictGetEntryVal(other));
+                        value = src[j].weight * zunionInterDictValue(other);
                         zunionInterAggregate(score, value, aggregate);
                     }
                 }
index 107bd6b146124e26a54127c2c2cc3d424acd4ede..cb78515d556464066e9f0823ab4936bc84159d87 100644 (file)
@@ -316,6 +316,14 @@ start_server default.conf {} {
         list [r zunionstore zsetc 2 zseta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores]
     } {4 {a 2 b 7 d 9 c 12}}
 
+       test {ZUNIONSTORE with a regular set and weights} {
+               r del seta
+               r sadd seta a
+               r sadd seta b
+               r sadd seta c
+        list [r zunionstore zsetc 2 seta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores]
+       } {4 {a 2 b 5 c 8 d 9}}
+
     test {ZUNIONSTORE with AGGREGATE MIN} {
         list [r zunionstore zsetc 2 zseta zsetb aggregate min] [r zrange zsetc 0 -1 withscores]
     } {4 {a 1 b 1 c 2 d 3}}
@@ -332,6 +340,14 @@ start_server default.conf {} {
         list [r zinterstore zsetc 2 zseta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores]
     } {2 {b 7 c 12}}
 
+       test {ZINTERSTORE with a regular set and weights} {
+               r del seta
+               r sadd seta a
+               r sadd seta b
+               r sadd seta c
+        list [r zinterstore zsetc 2 seta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores]
+       } {2 {b 5 c 8}}
+
     test {ZINTERSTORE with AGGREGATE MIN} {
         list [r zinterstore zsetc 2 zseta zsetb aggregate min] [r zrange zsetc 0 -1 withscores]
     } {2 {b 1 c 2}}