X-Git-Url: https://git.saurik.com/redis.git/blobdiff_plain/fdfdae0f3abdbe44905d2de1b89ea839717c2a65..17d68f9c99cd629591527d4c385ed0b1244726c5:/redis.c diff --git a/redis.c b/redis.c index 87575ee3..8faa1edb 100644 --- a/redis.c +++ b/redis.c @@ -5389,8 +5389,26 @@ static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) { return size1 - size2; } +#define REDIS_AGGR_SUM 1 +#define REDIS_AGGR_MIN 2 +#define REDIS_AGGR_MAX 3 + +inline static void zunionInterAggregate(double *target, double val, int aggregate) { + if (aggregate == REDIS_AGGR_SUM) { + *target = *target + val; + } else if (aggregate == REDIS_AGGR_MIN) { + *target = val < *target ? val : *target; + } else if (aggregate == REDIS_AGGR_MAX) { + *target = val > *target ? val : *target; + } else { + /* safety net */ + redisAssert(0 != 0); + } +} + static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { int i, j, zsetnum; + int aggregate = REDIS_AGGR_SUM; zsetopsrc *src; robj *dstobj; zset *dstzset; @@ -5431,19 +5449,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { /* parse optional extra arguments */ if (j < c->argc) { - int remaining = c->argc-j; + int remaining = c->argc - j; while (remaining) { - if (!strcasecmp(c->argv[j]->ptr,"weights")) { + if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) { j++; remaining--; - if (remaining < zsetnum) { - zfree(src); - addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n")); - return; - } for (i = 0; i < zsetnum; i++, j++, remaining--) { src[i].weight = strtod(c->argv[j]->ptr, NULL); } + } else if (remaining >= 2 && !strcasecmp(c->argv[j]->ptr,"aggregate")) { + j++; remaining--; + if (!strcasecmp(c->argv[j]->ptr,"sum")) { + aggregate = REDIS_AGGR_SUM; + } else if (!strcasecmp(c->argv[j]->ptr,"min")) { + aggregate = REDIS_AGGR_MIN; + } else if (!strcasecmp(c->argv[j]->ptr,"max")) { + aggregate = REDIS_AGGR_MAX; + } else { + zfree(src); + addReply(c,shared.syntaxerr); + return; + } + j++; remaining--; } else { zfree(src); addReply(c,shared.syntaxerr); @@ -5452,27 +5479,28 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { } } + /* sort sets from the smallest to largest, this will improve our + * algorithm's performance */ + qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality); + dstobj = createZsetObject(); dstzset = dstobj->ptr; if (op == REDIS_OP_INTER) { - /* sort sets from the smallest to largest, this will improve our - * algorithm's performance */ - qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality); - /* skip going over all entries if the smallest zset is NULL or empty */ if (src[0].dict && dictSize(src[0].dict) > 0) { /* precondition: as src[0].dict is non-empty and the zsets are ordered * from small to large, all src[i > 0].dict are non-empty too */ di = dictGetIterator(src[0].dict); while((de = dictNext(di)) != NULL) { - double *score = zmalloc(sizeof(double)); - *score = 0.0; + double *score = zmalloc(sizeof(double)), value; + *score = src[0].weight * (*(double*)dictGetEntryVal(de)); - for (j = 0; j < zsetnum; j++) { - dictEntry *other = (j == 0) ? de : dictFind(src[j].dict,dictGetEntryKey(de)); + for (j = 1; j < zsetnum; j++) { + dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de)); if (other) { - *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other)); + value = src[j].weight * (*(double*)dictGetEntryVal(other)); + zunionInterAggregate(score, value, aggregate); } else { break; } @@ -5500,14 +5528,16 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { /* skip key when already processed */ if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue; - double *score = zmalloc(sizeof(double)); - *score = 0.0; - for (j = 0; j < zsetnum; j++) { - if (!src[j].dict) continue; + double *score = zmalloc(sizeof(double)), value; + *score = src[i].weight * (*(double*)dictGetEntryVal(de)); - dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de)); + /* because the zsets are sorted by size, its only possible + * for sets at larger indices to hold this entry */ + for (j = (i+1); j < zsetnum; j++) { + dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de)); if (other) { - *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other)); + value = src[j].weight * (*(double*)dictGetEntryVal(other)); + zunionInterAggregate(score, value, aggregate); } }