return size1 - size2;
}
+/* Aggregation modes understood by zunionInterAggregate(). The command
+ * parser picks one from the "aggregate" option ("sum", "min" or "max")
+ * and falls back to REDIS_AGGR_SUM when the option is absent. */
+#define REDIS_AGGR_SUM 1
+#define REDIS_AGGR_MIN 2
+#define REDIS_AGGR_MAX 3
+
+/* Fold `val` into the running score stored at `*target`.
+ *
+ * `aggregate` selects how the two values are combined:
+ *   REDIS_AGGR_SUM  *target becomes *target + val
+ *   REDIS_AGGR_MIN  *target becomes the smaller of *target and val
+ *   REDIS_AGGR_MAX  *target becomes the larger of *target and val
+ * Any other value is a programming error and trips redisAssert.
+ *
+ * The result is written back through `target`; nothing is returned. */
+static inline void zunionInterAggregate(double *target, double val, int aggregate) {
+    if (aggregate == REDIS_AGGR_SUM) {
+        *target = *target + val;
+        /* Adding +inf and -inf (or any NaN operand) yields NaN, which is
+         * unordered with respect to every other double and therefore cannot
+         * be ranked against other scores. Store 0.0 for such a sum instead.
+         * NaN is the only value that compares unequal to itself, so this
+         * check needs no <math.h>. */
+        if (*target != *target) *target = 0.0;
+    } else if (aggregate == REDIS_AGGR_MIN) {
+        *target = val < *target ? val : *target;
+    } else if (aggregate == REDIS_AGGR_MAX) {
+        *target = val > *target ? val : *target;
+    } else {
+        /* safety net: unknown aggregation mode */
+        redisAssert(0 != 0);
+    }
+}
+
static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
int i, j, zsetnum;
+ int aggregate = REDIS_AGGR_SUM;
zsetopsrc *src;
robj *dstobj;
zset *dstzset;
/* parse optional extra arguments */
if (j < c->argc) {
- int remaining = c->argc-j;
+ int remaining = c->argc - j;
while (remaining) {
- if (!strcasecmp(c->argv[j]->ptr,"weights")) {
+ if (remaining >= (zsetnum + 1) && !strcasecmp(c->argv[j]->ptr,"weights")) {
j++; remaining--;
- if (remaining < zsetnum) {
- zfree(src);
- addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
- return;
- }
for (i = 0; i < zsetnum; i++, j++, remaining--) {
src[i].weight = strtod(c->argv[j]->ptr, NULL);
}
+ } else if (remaining >= 2 && !strcasecmp(c->argv[j]->ptr,"aggregate")) {
+ j++; remaining--;
+ if (!strcasecmp(c->argv[j]->ptr,"sum")) {
+ aggregate = REDIS_AGGR_SUM;
+ } else if (!strcasecmp(c->argv[j]->ptr,"min")) {
+ aggregate = REDIS_AGGR_MIN;
+ } else if (!strcasecmp(c->argv[j]->ptr,"max")) {
+ aggregate = REDIS_AGGR_MAX;
+ } else {
+ zfree(src);
+ addReply(c,shared.syntaxerr);
+ return;
+ }
+ j++; remaining--;
} else {
zfree(src);
addReply(c,shared.syntaxerr);
}
}
+ /* sort sets from the smallest to largest, this will improve our
+ * algorithm's performance */
+ qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
+
dstobj = createZsetObject();
dstzset = dstobj->ptr;
if (op == REDIS_OP_INTER) {
- /* sort sets from the smallest to largest, this will improve our
- * algorithm's performance */
- qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
-
/* skip going over all entries if the smallest zset is NULL or empty */
if (src[0].dict && dictSize(src[0].dict) > 0) {
/* precondition: as src[0].dict is non-empty and the zsets are ordered
* from small to large, all src[i > 0].dict are non-empty too */
di = dictGetIterator(src[0].dict);
while((de = dictNext(di)) != NULL) {
- double *score = zmalloc(sizeof(double));
- *score = 0.0;
+ double *score = zmalloc(sizeof(double)), value;
+ *score = src[0].weight * (*(double*)dictGetEntryVal(de));
- for (j = 0; j < zsetnum; j++) {
- dictEntry *other = (j == 0) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+ for (j = 1; j < zsetnum; j++) {
+ dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
if (other) {
- *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+ value = src[j].weight * (*(double*)dictGetEntryVal(other));
+ zunionInterAggregate(score, value, aggregate);
} else {
break;
}
/* skip key when already processed */
if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
- double *score = zmalloc(sizeof(double));
- *score = 0.0;
- for (j = 0; j < zsetnum; j++) {
- if (!src[j].dict) continue;
+ double *score = zmalloc(sizeof(double)), value;
+ *score = src[i].weight * (*(double*)dictGetEntryVal(de));
- dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
+ /* because the zsets are sorted by size, it's only possible
+ * for sets at larger indices to hold this entry */
+ for (j = (i+1); j < zsetnum; j++) {
+ dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
if (other) {
- *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
+ value = src[j].weight * (*(double*)dictGetEntryVal(other));
+ zunionInterAggregate(score, value, aggregate);
}
}