]> git.saurik.com Git - redis.git/blobdiff - redis.c
Merge branch 'aggregates' of git://github.com/pietern/redis
[redis.git] / redis.c
diff --git a/redis.c b/redis.c
index 87575ee3f529eb617660021cdc2abf34903e414b..8faa1edb4ac221e8306bdb78b52eb2b45295cfbf 100644 (file)
--- a/redis.c
+++ b/redis.c
@@ -5389,8 +5389,26 @@ static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) {
     return size1 - size2;
 }
 
+/* Aggregation modes selectable through the AGGREGATE option of ZUNION/ZINTER. */
+#define REDIS_AGGR_SUM 1
+#define REDIS_AGGR_MIN 2
+#define REDIS_AGGR_MAX 3
+
+/* Fold 'val' into the running score stored at '*target' according to
+ * 'aggregate':
+ *   REDIS_AGGR_SUM  -> *target becomes *target + val
+ *   REDIS_AGGR_MIN  -> *target becomes the smaller of the two values
+ *   REDIS_AGGR_MAX  -> *target becomes the larger of the two values
+ * Any other mode is a programming error and is reported via redisAssert. */
+static inline void zunionInterAggregate(double *target, double val, int aggregate) {
+    if (aggregate == REDIS_AGGR_SUM) {
+        *target = *target + val;
+        /* Adding +inf and -inf (e.g. via a negative WEIGHT) yields NaN, and
+         * NaN compares false against every value, so it cannot be ordered
+         * as a score. Follow the convention of turning it into 0.0. NaN is
+         * the only value that compares unequal to itself. */
+        if (*target != *target) *target = 0.0;
+    } else if (aggregate == REDIS_AGGR_MIN) {
+        *target = val < *target ? val : *target;
+    } else if (aggregate == REDIS_AGGR_MAX) {
+        *target = val > *target ? val : *target;
+    } else {
+        /* safety net: the option parser only ever selects the modes above */
+        redisAssert(0 != 0);
+    }
+}
+
 static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
     int i, j, zsetnum;
+    int aggregate = REDIS_AGGR_SUM;
     zsetopsrc *src;
     robj *dstobj;
     zset *dstzset;
@@ -5431,19 +5449,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
 
     /* parse optional extra arguments */
     if (j < c->argc) {
-        int remaining = c->argc-j;
+        int remaining = c->argc - j;
 
         while (remaining) {
-            if (!strcasecmp(c->argv[j]->ptr,"weights")) {
+            if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) {
                 j++; remaining--;
-                if (remaining < zsetnum) {
-                    zfree(src);
-                    addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
-                    return;
-                }
                 for (i = 0; i < zsetnum; i++, j++, remaining--) {
                     src[i].weight = strtod(c->argv[j]->ptr, NULL);
                 }
+            } else if (remaining >= 2 && !strcasecmp(c->argv[j]->ptr,"aggregate")) {
+                j++; remaining--;
+                if (!strcasecmp(c->argv[j]->ptr,"sum")) {
+                    aggregate = REDIS_AGGR_SUM;
+                } else if (!strcasecmp(c->argv[j]->ptr,"min")) {
+                    aggregate = REDIS_AGGR_MIN;
+                } else if (!strcasecmp(c->argv[j]->ptr,"max")) {
+                    aggregate = REDIS_AGGR_MAX;
+                } else {
+                    zfree(src);
+                    addReply(c,shared.syntaxerr);
+                    return;
+                }
+                j++; remaining--;
             } else {
                 zfree(src);
                 addReply(c,shared.syntaxerr);
@@ -5452,27 +5479,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
         }
     }
 
+    /* sort sets from the smallest to largest, this will improve our
+     * algorithm's performance */
+    qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
+
     dstobj = createZsetObject();
     dstzset = dstobj->ptr;
 
     if (op == REDIS_OP_INTER) {
-        /* sort sets from the smallest to largest, this will improve our
-         * algorithm's performance */
-        qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
-
         /* skip going over all entries if the smallest zset is NULL or empty */
         if (src[0].dict && dictSize(src[0].dict) > 0) {
             /* precondition: as src[0].dict is non-empty and the zsets are ordered
              * from small to large, all src[i > 0].dict are non-empty too */
             di = dictGetIterator(src[0].dict);
             while((de = dictNext(di)) != NULL) {
-                double *score = zmalloc(sizeof(double));
-                *score = 0.0;
+                double *score = zmalloc(sizeof(double)), value;
+                *score = src[0].weight * (*(double*)dictGetEntryVal(de));
 
-                for (j = 0; j < zsetnum; j++) {
-                    dictEntry *other = (j == 0) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+                for (j = 1; j < zsetnum; j++) {
+                    dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+                        value = src[j].weight * (*(double*)dictGetEntryVal(other));
+                        zunionInterAggregate(score, value, aggregate);
                     } else {
                         break;
                     }
@@ -5500,14 +5528,16 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
                 /* skip key when already processed */
                 if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
 
-                double *score = zmalloc(sizeof(double));
-                *score = 0.0;
-                for (j = 0; j < zsetnum; j++) {
-                    if (!src[j].dict) continue;
+                double *score = zmalloc(sizeof(double)), value;
+                *score = src[i].weight * (*(double*)dictGetEntryVal(de));
 
-                    dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+                /* because the zsets are sorted by size, its only possible
+                 * for sets at larger indices to hold this entry */
+                for (j = (i+1); j < zsetnum; j++) {
+                    dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+                        value = src[j].weight * (*(double*)dictGetEntryVal(other));
+                        zunionInterAggregate(score, value, aggregate);
                     }
                 }