From bc000c1db00dafaba9aae9620b246b63e4130238 Mon Sep 17 00:00:00 2001 From: Josiah Carlson Date: Tue, 25 May 2010 17:41:35 -0700 Subject: [PATCH] allow regular sets to be passed to zunionstore/zinterstore --- redis.c | 48 ++++++++++++++++++++++------------------ tests/unit/type/zset.tcl | 16 ++++++++++++++ 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/redis.c b/redis.c index 42dc00ac..f18eabe6 100644 --- a/redis.c +++ b/redis.c @@ -5922,6 +5922,7 @@ static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) { #define REDIS_AGGR_SUM 1 #define REDIS_AGGR_MIN 2 #define REDIS_AGGR_MAX 3 +#define zunionInterDictValue(_e) (dictGetEntryVal(_e) == NULL ? 1.0 : *(double*)dictGetEntryVal(_e)) inline static void zunionInterAggregate(double *target, double val, int aggregate) { if (aggregate == REDIS_AGGR_SUM) { @@ -5937,7 +5938,7 @@ inline static void zunionInterAggregate(double *target, double val, int aggregat } static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { - int i, j, zsetnum; + int i, j, setnum; int aggregate = REDIS_AGGR_SUM; zsetopsrc *src; robj *dstobj; @@ -5945,32 +5946,35 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { dictIterator *di; dictEntry *de; - /* expect zsetnum input keys to be given */ - zsetnum = atoi(c->argv[2]->ptr); - if (zsetnum < 1) { + /* expect setnum input keys to be given */ + setnum = atoi(c->argv[2]->ptr); + if (setnum < 1) { addReplySds(c,sdsnew("-ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE\r\n")); return; } /* test if the expected number of keys would overflow */ - if (3+zsetnum > c->argc) { + if (3+setnum > c->argc) { addReply(c,shared.syntaxerr); return; } /* read keys to be used for input */ - src = zmalloc(sizeof(zsetopsrc) * zsetnum); - for (i = 0, j = 3; i < zsetnum; i++, j++) { - robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]); - if (!zsetobj) { + src = zmalloc(sizeof(zsetopsrc) * setnum); + for (i = 0, j = 3; i < setnum; i++, j++) { + robj *obj = lookupKeyWrite(c->db,c->argv[j]); + if (!obj) { src[i].dict = NULL; } else { - if (zsetobj->type != REDIS_ZSET) { + if (obj->type == REDIS_ZSET) { + src[i].dict = ((zset*)obj->ptr)->dict; + } else if (obj->type == REDIS_SET) { + src[i].dict = (obj->ptr); + } else { zfree(src); addReply(c,shared.wrongtypeerr); return; } - src[i].dict = ((zset*)zsetobj->ptr)->dict; } /* default all weights to 1 */ @@ -5982,9 +5986,9 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { int remaining = c->argc - j; while (remaining) { - if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) { + if (remaining >= (setnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) { j++; remaining--; - for (i = 0; i < zsetnum; i++, j++, remaining--) { + for (i = 0; i < setnum; i++, j++, remaining--) { if (getDoubleFromObjectOrReply(c, c->argv[j], &src[i].weight, NULL) != REDIS_OK) return; } @@ -6012,7 +6016,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { /* sort sets from the smallest to largest, this will improve our * algorithm's performance */ - qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality); + qsort(src,setnum,sizeof(zsetopsrc),qsortCompareZsetopsrcByCardinality); dstobj = createZsetObject(); dstzset = dstobj->ptr; @@ -6025,12 +6029,12 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { di = dictGetIterator(src[0].dict); while((de = dictNext(di)) != NULL) { double *score = zmalloc(sizeof(double)), value; - *score = src[0].weight * (*(double*)dictGetEntryVal(de)); + *score = src[0].weight * zunionInterDictValue(de); - for (j = 1; j < zsetnum; j++) { + for (j = 1; j < setnum; j++) { dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de)); if (other) { - value = src[j].weight * (*(double*)dictGetEntryVal(other)); + value = src[j].weight * zunionInterDictValue(other); zunionInterAggregate(score, value, aggregate); } else { break; @@ -6038,7 +6042,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { } /* skip entry when not present in every source dict */ - if (j != zsetnum) { + if (j != setnum) { zfree(score); } else { robj *o = dictGetEntryKey(de); @@ -6051,7 +6055,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { dictReleaseIterator(di); } } else if (op == REDIS_OP_UNION) { - for (i = 0; i < zsetnum; i++) { + for (i = 0; i < setnum; i++) { if (!src[i].dict) continue; di = dictGetIterator(src[i].dict); @@ -6060,14 +6064,14 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) { if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue; double *score = zmalloc(sizeof(double)), value; - *score = src[i].weight * (*(double*)dictGetEntryVal(de)); + *score = src[i].weight * zunionInterDictValue(de); /* because the zsets are sorted by size, its only possible * for sets at larger indices to hold this entry */ - for (j = (i+1); j < zsetnum; j++) { + for (j = (i+1); j < setnum; j++) { dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de)); if (other) { - value = src[j].weight * (*(double*)dictGetEntryVal(other)); + value = src[j].weight * zunionInterDictValue(other); zunionInterAggregate(score, value, aggregate); } } diff --git a/tests/unit/type/zset.tcl b/tests/unit/type/zset.tcl index 107bd6b1..cb78515d 100644 --- a/tests/unit/type/zset.tcl +++ b/tests/unit/type/zset.tcl @@ -316,6 +316,14 @@ start_server default.conf {} { list [r zunionstore zsetc 2 zseta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores] } {4 {a 2 b 7 d 9 c 12}} + test {ZUNIONSTORE with a regular set and weights} { + r del seta + r sadd seta a + r sadd seta b + r sadd seta c + list [r zunionstore zsetc 2 seta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores] + } {4 {a 2 b 5 c 8 d 9}} + test {ZUNIONSTORE with AGGREGATE MIN} { list [r zunionstore zsetc 2 zseta zsetb aggregate min] [r zrange zsetc 0 -1 withscores] } {4 {a 1 b 1 c 2 d 3}} @@ -332,6 +340,14 @@ start_server default.conf {} { list [r zinterstore zsetc 2 zseta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores] } {2 {b 7 c 12}} + test {ZINTERSTORE with a regular set and weights} { + r del seta + r sadd seta a + r sadd seta b + r sadd seta c + list [r zinterstore zsetc 2 seta zsetb weights 2 3] [r zrange zsetc 0 -1 withscores] + } {2 {b 5 c 8}} + test {ZINTERSTORE with AGGREGATE MIN} { list [r zinterstore zsetc 2 zseta zsetb aggregate min] [r zrange zsetc 0 -1 withscores] } {2 {b 1 c 2}} -- 2.47.2