added explicit AGGREGATE [SUM|MIN|MAX] option to ZUNION/ZINTER
author Pieter Noordhuis <pcnoordhuis@gmail.com>
Sat, 13 Mar 2010 22:27:22 +0000 (23:27 +0100)
committer Pieter Noordhuis <pcnoordhuis@gmail.com>
Tue, 16 Mar 2010 19:34:45 +0000 (20:34 +0100)
redis.c
test-redis.tcl

diff --git a/redis.c b/redis.c
index 302944de8c4b92e7ff52ee4d26788fdc498e9f75..075f1f81e7b87e9a514ec62c48a961cc27b2dbc3 100644 (file)
--- a/redis.c
+++ b/redis.c
@@ -5387,8 +5387,26 @@ static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) {
     return size1 - size2;
 }
 
+#define REDIS_AGGR_SUM 1
+#define REDIS_AGGR_MIN 2
+#define REDIS_AGGR_MAX 3
+/* Fold val into *target per aggregate: SUM adds, MIN keeps the smaller, MAX keeps the larger. */
+inline static void zunionInterAggregate(double *target, double val, int aggregate) {
+    if (aggregate == REDIS_AGGR_SUM) {
+        *target = *target + val;
+    } else if (aggregate == REDIS_AGGR_MIN) {
+        *target = val < *target ? val : *target;
+    } else if (aggregate == REDIS_AGGR_MAX) {
+        *target = val > *target ? val : *target;
+    } else {
+        /* safety net: aggregate is none of the REDIS_AGGR_* modes, so assert unconditionally */
+        redisAssert(0 != 0);
+    }
+}
+
 static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
     int i, j, zsetnum;
+    int aggregate = REDIS_AGGR_SUM;
     zsetopsrc *src;
     robj *dstobj;
     zset *dstzset;
@@ -5429,19 +5447,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
 
     /* parse optional extra arguments */
     if (j < c->argc) {
-        int remaining = c->argc-j;
+        int remaining = c->argc - j;
 
         while (remaining) {
-            if (!strcasecmp(c->argv[j]->ptr,"weights")) {
+            if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) {
                 j++; remaining--;
-                if (remaining < zsetnum) {
-                    zfree(src);
-                    addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
-                    return;
-                }
                 for (i = 0; i < zsetnum; i++, j++, remaining--) {
                     src[i].weight = strtod(c->argv[j]->ptr, NULL);
                 }
+            } else if (remaining >= 2 && !strcasecmp(c->argv[j]->ptr,"aggregate")) {
+                j++; remaining--;
+                if (!strcasecmp(c->argv[j]->ptr,"sum")) {
+                    aggregate = REDIS_AGGR_SUM;
+                } else if (!strcasecmp(c->argv[j]->ptr,"min")) {
+                    aggregate = REDIS_AGGR_MIN;
+                } else if (!strcasecmp(c->argv[j]->ptr,"max")) {
+                    aggregate = REDIS_AGGR_MAX;
+                } else {
+                    zfree(src);
+                    addReply(c,shared.syntaxerr);
+                    return;
+                }
+                j++; remaining--;
             } else {
                 zfree(src);
                 addReply(c,shared.syntaxerr);
@@ -5450,27 +5477,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
         }
     }
 
+    /* sort sets from the smallest to largest, this will improve our
+     * algorithm's performance */
+    qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
+
     dstobj = createZsetObject();
     dstzset = dstobj->ptr;
 
     if (op == REDIS_OP_INTER) {
-        /* sort sets from the smallest to largest, this will improve our
-         * algorithm's performance */
-        qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
-
         /* skip going over all entries if the smallest zset is NULL or empty */
         if (src[0].dict && dictSize(src[0].dict) > 0) {
             /* precondition: as src[0].dict is non-empty and the zsets are ordered
              * from small to large, all src[i > 0].dict are non-empty too */
             di = dictGetIterator(src[0].dict);
             while((de = dictNext(di)) != NULL) {
-                double *score = zmalloc(sizeof(double));
-                *score = 0.0;
+                double *score = zmalloc(sizeof(double)), value;
+                *score = src[0].weight * (*(double*)dictGetEntryVal(de));
 
-                for (j = 0; j < zsetnum; j++) {
-                    dictEntry *other = (j == 0) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+                for (j = 1; j < zsetnum; j++) {
+                    dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+                        value = src[j].weight * (*(double*)dictGetEntryVal(other));
+                        zunionInterAggregate(score, value, aggregate);
                     } else {
                         break;
                     }
@@ -5498,14 +5526,16 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
                 /* skip key when already processed */
                 if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
 
-                double *score = zmalloc(sizeof(double));
-                *score = 0.0;
-                for (j = 0; j < zsetnum; j++) {
-                    if (!src[j].dict) continue;
+                double *score = zmalloc(sizeof(double)), value;
+                *score = src[i].weight * (*(double*)dictGetEntryVal(de));
 
-                    dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+                /* because the zsets are sorted by size, it's only possible
+                 * for sets at larger indices to hold this entry */
+                for (j = (i+1); j < zsetnum; j++) {
+                    dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+                        value = src[j].weight * (*(double*)dictGetEntryVal(other));
+                        zunionInterAggregate(score, value, aggregate);
                     }
                 }
 
diff --git a/test-redis.tcl b/test-redis.tcl
index 66aa0b30a872610f416e56a1167138edc9323649..1b1f82e900b5b36effb67b83bde3a38112f6aeb7 100644 (file)
--- a/test-redis.tcl
+++ b/test-redis.tcl
@@ -1491,6 +1491,14 @@ proc main {server port} {
         list [$r zunion zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores]
     } {4 {a 2 b 7 d 9 c 12}}
 
+    test {ZUNION with AGGREGATE MIN} {
+        list [$r zunion zsetc 2 zseta zsetb aggregate min] [$r zrange zsetc 0 -1 withscores]
+    } {4 {a 1 b 1 c 2 d 3}}
+
+    test {ZUNION with AGGREGATE MAX} {
+        list [$r zunion zsetc 2 zseta zsetb aggregate max] [$r zrange zsetc 0 -1 withscores]
+    } {4 {a 1 b 2 c 3 d 3}}
+
     test {ZINTER basics} {
         list [$r zinter zsetc 2 zseta zsetb] [$r zrange zsetc 0 -1 withscores]
     } {2 {b 3 c 5}}
@@ -1499,6 +1507,14 @@ proc main {server port} {
         list [$r zinter zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores]
     } {2 {b 7 c 12}}
 
+    test {ZINTER with AGGREGATE MIN} {
+        list [$r zinter zsetc 2 zseta zsetb aggregate min] [$r zrange zsetc 0 -1 withscores]
+    } {2 {b 1 c 2}}
+
+    test {ZINTER with AGGREGATE MAX} {
+        list [$r zinter zsetc 2 zseta zsetb aggregate max] [$r zrange zsetc 0 -1 withscores]
+    } {2 {b 2 c 3}}
+
     test {SORT against sorted sets} {
         $r del zset
         $r zadd zset 1 a