static void zrankCommand(redisClient *c);
static void hsetCommand(redisClient *c);
static void hgetCommand(redisClient *c);
-static void zmergeCommand(redisClient *c);
-static void zmergeweighedCommand(redisClient *c);
+/* ZUNION/ZINTER supersede ZMERGE/ZMERGEWEIGHED; per-key weights are now an
+ * optional WEIGHTS clause instead of a separate command. */
+static void zunionCommand(redisClient *c);
+static void zinterCommand(redisClient *c);
/*================================= Globals ================================= */
{"zincrby",zincrbyCommand,4,REDIS_CMD_BULK|REDIS_CMD_DENYOOM,1,1,1},
{"zrem",zremCommand,3,REDIS_CMD_BULK,1,1,1},
{"zremrangebyscore",zremrangebyscoreCommand,4,REDIS_CMD_INLINE,1,1,1},
- {"zmerge",zmergeCommand,-3,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,2,-1,1},
- {"zmergeweighed",zmergeweighedCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,2,-2,2},
+ /* zunion/zinter: arity -4 = at least "cmd dstkey numkeys key". NOTE(review):
+  * the trailing 0,0,0 appears to disable the static first/last/step key spec
+  * (compare zmerge's 2,-1,1): input keys start at argv[3] and their count
+  * comes from argv[2], so they cannot be described by fixed positions. */
+ {"zunion",zunionCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,0,0,0},
+ {"zinter",zinterCommand,-4,REDIS_CMD_INLINE|REDIS_CMD_DENYOOM,0,0,0},
{"zrange",zrangeCommand,-4,REDIS_CMD_INLINE,1,1,1},
{"zrangebyscore",zrangebyscoreCommand,-4,REDIS_CMD_INLINE,1,1,1},
{"zcount",zcountCommand,4,REDIS_CMD_INLINE,1,1,1},
#define REDIS_OP_UNION 0
#define REDIS_OP_DIFF 1
+/* zunionInterGenericCommand accepts REDIS_OP_UNION and REDIS_OP_INTER as its
+ * op argument; any other value trips its redisAssert. */
+#define REDIS_OP_INTER 2
static void sunionDiffGenericCommand(redisClient *c, robj **setskeys, int setsnum, robj *dstkey, int op) {
dict **dv = zmalloc(sizeof(dict*)*setsnum);
}
}
-/* This command merges 2 or more zsets to a destination. When an element
- * does not exist in a certain set, score 0 is assumed. The score for an
- * element across sets is summed. */
-static void zmergeGenericCommand(redisClient *c, int readweights) {
- int i, j, zsetnum;
- dict **srcdict;
+/* Implements ZUNION and ZINTER:
+ *
+ *   <cmd> dstkey numkeys key1 ... keyN [WEIGHTS w1 ... wN]
+ *
+ * argv[2] gives the number of input keys, which follow from argv[3]. Every
+ * input score is multiplied by its key's weight (default 1.0) and the
+ * weighted scores are summed per element.
+ * - REDIS_OP_UNION: every element present in at least one input is stored;
+ *   non-existing input keys are skipped.
+ * - REDIS_OP_INTER: only elements present in every input are stored, so a
+ *   non-existing or empty input yields an empty result.
+ * The result replaces dstkey; the reply is the number of stored elements. */
+static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
+ int i, j, k, zsetnum;
+ dict **srcdicts;
double *weights;
- robj *dstkey = c->argv[1], *dstobj;
- zset *dst;
+ robj *dstobj;
+ zset *dstzset;
+ long numkeys;
dictIterator *di;
dictEntry *de;
- zsetnum = c->argc-2;
- if (readweights) {
- /* force number of arguments to be even */
- if (zsetnum % 2 > 0) {
- addReplySds(c,sdsnew("-ERR wrong number of arguments for ZMERGEWEIGHED\r\n"));
- return;
- }
- zsetnum /= 2;
+ /* expect numkeys input keys to be given. strtol is used instead of atoi
+  * because atoi has undefined behavior when the value is out of range. */
+ numkeys = strtol(c->argv[2]->ptr, NULL, 10);
+ if (numkeys < 1) {
+ addReplySds(c,sdsnew("-ERR at least 1 input key is needed for ZUNION/ZINTER\r\n"));
+ return;
}
- if (!zsetnum) {
+
+ /* make sure argv really holds numkeys keys after "cmd dstkey numkeys".
+  * Comparing against argc-3 (instead of computing 3+numkeys) cannot
+  * overflow, so a huge numkeys can no longer slip past this check and
+  * make the loop below read beyond argv. */
+ if (numkeys > c->argc-3) {
addReply(c,shared.syntaxerr);
return;
}
+ zsetnum = (int)numkeys; /* safe: 1 <= numkeys <= argc-3 */
- srcdict = zmalloc(sizeof(dict*) * zsetnum);
+ /* read keys to be used for input */
+ srcdicts = zmalloc(sizeof(dict*) * zsetnum);
weights = zmalloc(sizeof(double) * zsetnum);
- for (i = 0; i < zsetnum; i++) {
- if (readweights) {
- j = 2 + 2*i;
- weights[i] = strtod(c->argv[j+1]->ptr, NULL);
- } else {
- j = 2 + i;
- weights[i] = 1.0;
- }
-
+ for (i = 0, j = 3; i < zsetnum; i++, j++) {
robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]);
if (!zsetobj) {
- srcdict[i] = NULL;
+ srcdicts[i] = NULL;
} else {
if (zsetobj->type != REDIS_ZSET) {
- zfree(srcdict);
+ zfree(srcdicts);
zfree(weights);
addReply(c,shared.wrongtypeerr);
return;
}
- srcdict[i] = ((zset*)zsetobj->ptr)->dict;
+ srcdicts[i] = ((zset*)zsetobj->ptr)->dict;
}
+
+ /* default all weights to 1 */
+ weights[i] = 1.0;
}
- dstobj = createZsetObject();
- dst = dstobj->ptr;
- for (i = 0; i < zsetnum; i++) {
- if (!srcdict[i]) continue;
+ /* parse optional extra arguments; j now indexes the first argument
+  * after the input keys */
+ if (j < c->argc) {
+ int remaining = c->argc-j;
- di = dictGetIterator(srcdict[i]);
- while((de = dictNext(di)) != NULL) {
- /* skip key when already processed */
- if (dictFind(dst->dict,dictGetEntryKey(de)) != NULL) continue;
+ while (remaining) {
+ if (!strcasecmp(c->argv[j]->ptr,"weights")) {
+ j++; remaining--;
+ if (remaining < zsetnum) {
+ zfree(srcdicts);
+ zfree(weights);
+ addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
+ return;
+ }
+ for (i = 0; i < zsetnum; i++, j++, remaining--) {
+ weights[i] = strtod(c->argv[j]->ptr, NULL);
+ }
+ } else {
+ zfree(srcdicts);
+ zfree(weights);
+ addReply(c,shared.syntaxerr);
+ return;
+ }
+ }
+ }
- double *score = zmalloc(sizeof(double));
- *score = 0.0;
- for (j = 0; j < zsetnum; j++) {
- if (!srcdict[j]) continue;
+ dstobj = createZsetObject();
+ dstzset = dstobj->ptr;
+
+ if (op == REDIS_OP_INTER) {
+ /* store index of smallest zset in variable j, so the outer iteration
+  * visits as few candidate elements as possible */
+ for (i = 0, j = 0; i < zsetnum; i++) {
+ if (!srcdicts[i] || dictSize(srcdicts[i]) == 0) {
+ break;
+ }
+ if (dictSize(srcdicts[i]) < dictSize(srcdicts[j])) {
+ j = i;
+ }
+ }
+ /* skip going over all entries if at least one dict was NULL or empty */
+ if (i == zsetnum) {
+ /* precondition: all srcdicts are non-NULL and non-empty */
+ di = dictGetIterator(srcdicts[j]);
+ while((de = dictNext(di)) != NULL) {
+ double *score = zmalloc(sizeof(double));
+ *score = 0.0;
+
+ for (k = 0; k < zsetnum; k++) {
+ dictEntry *other = (k == j) ? de : dictFind(srcdicts[k],dictGetEntryKey(de));
+ if (other) {
+ *score = *score + weights[k] * (*(double*)dictGetEntryVal(other));
+ } else {
+ break;
+ }
+ }
- dictEntry *other = dictFind(srcdict[j],dictGetEntryKey(de));
- if (other) {
- *score = *score + weights[j] * (*(double*)dictGetEntryVal(other));
+ /* skip entry when not present in every source dict */
+ if (k != zsetnum) {
+ zfree(score);
+ } else {
+ robj *o = dictGetEntryKey(de);
+ dictAdd(dstzset->dict,o,score);
+ incrRefCount(o); /* added to dictionary */
+ zslInsert(dstzset->zsl,*score,o);
+ incrRefCount(o); /* added to skiplist */
}
}
+ dictReleaseIterator(di);
+ }
+ } else if (op == REDIS_OP_UNION) {
+ for (i = 0; i < zsetnum; i++) {
+ if (!srcdicts[i]) continue;
+
+ di = dictGetIterator(srcdicts[i]);
+ while((de = dictNext(di)) != NULL) {
+ /* skip key when already processed */
+ if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
+
+ double *score = zmalloc(sizeof(double));
+ *score = 0.0;
+ for (j = 0; j < zsetnum; j++) {
+ if (!srcdicts[j]) continue;
+
+ dictEntry *other = (i == j) ? de : dictFind(srcdicts[j],dictGetEntryKey(de));
+ if (other) {
+ *score = *score + weights[j] * (*(double*)dictGetEntryVal(other));
+ }
+ }
- robj *o = dictGetEntryKey(de);
- dictAdd(dst->dict,o,score);
- incrRefCount(o); /* added to dictionary */
- zslInsert(dst->zsl,*score,o);
- incrRefCount(o); /* added to skiplist */
+ robj *o = dictGetEntryKey(de);
+ dictAdd(dstzset->dict,o,score);
+ incrRefCount(o); /* added to dictionary */
+ zslInsert(dstzset->zsl,*score,o);
+ incrRefCount(o); /* added to skiplist */
+ }
+ dictReleaseIterator(di);
}
- dictReleaseIterator(di);
+ } else {
+ /* unknown operator */
+ redisAssert(op == REDIS_OP_INTER || op == REDIS_OP_UNION);
}
deleteKey(c->db,dstkey);
dictAdd(c->db->dict,dstkey,dstobj);
incrRefCount(dstkey);
- addReplyLong(c, dst->zsl->length);
+ addReplyLong(c, dstzset->zsl->length);
server.dirty++;
- zfree(srcdict);
+ zfree(srcdicts);
zfree(weights);
}
-static void zmergeCommand(redisClient *c) {
- zmergeGenericCommand(c,0);
+/* ZUNION dstkey numkeys key1 ... keyN [WEIGHTS w1 ... wN]
+ * Stores the weighted-sum union of the input zsets into argv[1]. */
+static void zunionCommand(redisClient *c) {
+ zunionInterGenericCommand(c,c->argv[1], REDIS_OP_UNION);
}
-static void zmergeweighedCommand(redisClient *c) {
- zmergeGenericCommand(c,1);
+/* ZINTER dstkey numkeys key1 ... keyN [WEIGHTS w1 ... wN]
+ * Stores the weighted-sum intersection of the input zsets into argv[1]. */
+static void zinterCommand(redisClient *c) {
+ zunionInterGenericCommand(c,c->argv[1], REDIS_OP_INTER);
}
static void zrangeGenericCommand(redisClient *c, int reverse) {