From: Pieter Noordhuis Date: Sat, 13 Mar 2010 22:27:22 +0000 (+0100) Subject: added explicit AGGREGATE [SUM|MIN|MAX] option to ZUNION/ZINTER X-Git-Url: https://git.saurik.com/redis.git/commitdiff_plain/d2764cd6926fb4368ccf55f381aad7386afd5154 added explicit AGGREGATE [SUM|MIN|MAX] option to ZUNION/ZINTER --- diff --git a/redis.c b/redis.c index 302944de..075f1f81 100644 --- a/redis.c +++ b/redis.c @@ -5387,8 +5387,26 @@ static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) { return size1 - size2; } +#define REDIS_AGGR_SUM 1 +#define REDIS_AGGR_MIN 2 +#define REDIS_AGGR_MAX 3 + +inline static void zunionInterAggregate(double *target, double val, int aggregate) { + if (aggregate == REDIS_AGGR_SUM) { + *target = *target + val; + } else if (aggregate == REDIS_AGGR_MIN) { + *target = val < *target ? val : *target; + } else if (aggregate == REDIS_AGGR_MAX) { + *target = val > *target ? val : *target; + } else { + /* safety net: unknown aggregate mode, unreachable because the argument parser only ever sets SUM, MIN or MAX */ + redisAssert(0 != 0); + } +} + static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { int i, j, zsetnum; + int aggregate = REDIS_AGGR_SUM; zsetopsrc *src; robj *dstobj; zset *dstzset; @@ -5429,19 +5447,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { /* parse optional extra arguments */ if (j < c->argc) { - int remaining = c->argc-j; + int remaining = c->argc - j; while (remaining) { - if (!strcasecmp(c->argv[j]->ptr,"weights")) { + if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) { j++; remaining--; - if (remaining < zsetnum) { - zfree(src); - addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n")); - return; - } for (i = 0; i < zsetnum; i++, j++, remaining--) { src[i].weight = strtod(c->argv[j]->ptr, NULL); } + } else if (remaining >= 2 && !strcasecmp(c->argv[j]->ptr,"aggregate")) { + j++; remaining--; + if (!strcasecmp(c->argv[j]->ptr,"sum")) { + aggregate = REDIS_AGGR_SUM; + } else if (!strcasecmp(c->argv[j]->ptr,"min")) 
{ + aggregate = REDIS_AGGR_MIN; + } else if (!strcasecmp(c->argv[j]->ptr,"max")) { + aggregate = REDIS_AGGR_MAX; + } else { + zfree(src); + addReply(c,shared.syntaxerr); + return; + } + j++; remaining--; } else { zfree(src); addReply(c,shared.syntaxerr); @@ -5450,27 +5477,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { } } + /* sort sets from the smallest to largest: this improves the ZINTER + * loop's performance, and the ZUNION loop below relies on this order too */ + qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality); + dstobj = createZsetObject(); dstzset = dstobj->ptr; if (op == REDIS_OP_INTER) { - /* sort sets from the smallest to largest, this will improve our - * algorithm's performance */ - qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality); - /* skip going over all entries if the smallest zset is NULL or empty */ if (src[0].dict && dictSize(src[0].dict) > 0) { /* precondition: as src[0].dict is non-empty and the zsets are ordered * from small to large, all src[i > 0].dict are non-empty too */ di = dictGetIterator(src[0].dict); while((de = dictNext(di)) != NULL) { - double *score = zmalloc(sizeof(double)); - *score = 0.0; + double *score = zmalloc(sizeof(double)), value; + *score = src[0].weight * (*(double*)dictGetEntryVal(de)); - for (j = 0; j < zsetnum; j++) { - dictEntry *other = (j == 0) ? 
de : dictFind(src[j].dict,dictGetEntryKey(de)); + for (j = 1; j < zsetnum; j++) { + dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de)); if (other) { - *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other)); + value = src[j].weight * (*(double*)dictGetEntryVal(other)); + zunionInterAggregate(score, value, aggregate); } else { break; } @@ -5498,14 +5526,16 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { /* skip key when already processed */ if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue; - double *score = zmalloc(sizeof(double)); - *score = 0.0; - for (j = 0; j < zsetnum; j++) { - if (!src[j].dict) continue; + double *score = zmalloc(sizeof(double)), value; + *score = src[i].weight * (*(double*)dictGetEntryVal(de)); - dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de)); + /* keys that also occur in sets at smaller indices were already added to + * dstzset (and skipped above), so only sets at larger indices can still hold this entry; since the sets are sorted by size and src[i] holds de, every src[j] with j > i is non-empty, so no NULL-dict check is needed here */ + for (j = (i+1); j < zsetnum; j++) { + dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de)); if (other) { - *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other)); + value = src[j].weight * (*(double*)dictGetEntryVal(other)); + zunionInterAggregate(score, value, aggregate); } } diff --git a/test-redis.tcl b/test-redis.tcl index 66aa0b30..1b1f82e9 100644 --- a/test-redis.tcl +++ b/test-redis.tcl @@ -1491,6 +1491,14 @@ proc main {server port} { list [$r zunion zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores] } {4 {a 2 b 7 d 9 c 12}} + test {ZUNION with AGGREGATE MIN} { + list [$r zunion zsetc 2 zseta zsetb aggregate min] [$r zrange zsetc 0 -1 withscores] + } {4 {a 1 b 1 c 2 d 3}} + + test {ZUNION with AGGREGATE MAX} { + list [$r zunion zsetc 2 zseta zsetb aggregate max] [$r zrange zsetc 0 -1 withscores] + } {4 {a 1 b 2 c 3 d 3}} + test {ZINTER basics} { list [$r zinter zsetc 2 zseta zsetb] [$r zrange zsetc 0 -1 withscores] } {2 {b 3 c 5}} @@ -1499,6 
+1507,14 @@ proc main {server port} { list [$r zinter zsetc 2 zseta zsetb weights 2 3] [$r zrange zsetc 0 -1 withscores] } {2 {b 7 c 12}} + test {ZINTER with AGGREGATE MIN} { + list [$r zinter zsetc 2 zseta zsetb aggregate min] [$r zrange zsetc 0 -1 withscores] + } {2 {b 1 c 2}} + + test {ZINTER with AGGREGATE MAX} { + list [$r zinter zsetc 2 zseta zsetb aggregate max] [$r zrange zsetc 0 -1 withscores] + } {2 {b 2 c 3}} + test {SORT against sorted sets} { $r del zset $r zadd zset 1 a