git.saurik.com Git - redis.git/commitdiff
use a struct to store both a dict and its weight for ZUNION and ZINTER, so qsort...
author Pieter Noordhuis <pcnoordhuis@gmail.com>
Tue, 9 Mar 2010 15:12:34 +0000 (16:12 +0100)
committer Pieter Noordhuis <pcnoordhuis@gmail.com>
Tue, 9 Mar 2010 15:12:34 +0000 (16:12 +0100)
redis.c

diff --git a/redis.c b/redis.c
index 94ab6720b1f3db1671dfbe65c01641530f269a1d..cc64efb846a49d0ae19ff4c0702072e413d89167 100644 (file)
--- a/redis.c
+++ b/redis.c
@@ -5330,10 +5330,22 @@ static void zremrangebyscoreCommand(redisClient *c) {
     }
 }
 
+typedef struct {
+    dict *dict;     /* dict of one source zset; NULL when the source key does not exist */
+    double weight;  /* factor each score from this source is multiplied by (defaults to 1.0) */
+} zsetopsrc;
+
+/* qsort() comparator: order zsetopsrc entries by dict size, smallest first; a
+ * NULL dict counts as size 0. Compares instead of subtracting to avoid wrap. */
+static int qsortCompareZsetopsrcByCardinality(const void *s1, const void *s2) {
+    const zsetopsrc *a = s1, *b = s2;
+    unsigned long na = a->dict ? dictSize(a->dict) : 0, nb = b->dict ? dictSize(b->dict) : 0;
+    return (na > nb) - (na < nb); /* -1, 0 or 1 */
+}
+
 static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
-    int i, j, k, zsetnum;
-    dict **srcdicts;
-    double *weights;
+    int i, j, zsetnum;
+    zsetopsrc *src;
     robj *dstobj;
     zset *dstzset;
     dictIterator *di;
@@ -5353,24 +5365,22 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
     }
 
     /* read keys to be used for input */
-    srcdicts = zmalloc(sizeof(dict*) * zsetnum);
-    weights = zmalloc(sizeof(double) * zsetnum);
+    src = malloc(sizeof(zsetopsrc) * zsetnum);
     for (i = 0, j = 3; i < zsetnum; i++, j++) {
         robj *zsetobj = lookupKeyWrite(c->db,c->argv[j]);
         if (!zsetobj) {
-            srcdicts[i] = NULL;
+            src[i].dict = NULL;
         } else {
             if (zsetobj->type != REDIS_ZSET) {
-                zfree(srcdicts);
-                zfree(weights);
+                zfree(src);
                 addReply(c,shared.wrongtypeerr);
                 return;
             }
-            srcdicts[i] = ((zset*)zsetobj->ptr)->dict;
+            src[i].dict = ((zset*)zsetobj->ptr)->dict;
         }
 
         /* default all weights to 1 */
-        weights[i] = 1.0;
+        src[i].weight = 1.0;
     }
 
     /* parse optional extra arguments */
@@ -5381,17 +5391,15 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
             if (!strcasecmp(c->argv[j]->ptr,"weights")) {
                 j++; remaining--;
                 if (remaining < zsetnum) {
-                    zfree(srcdicts);
-                    zfree(weights);
+                    zfree(src);
                     addReplySds(c,sdsnew("-ERR not enough weights for ZUNION/ZINTER\r\n"));
                     return;
                 }
                 for (i = 0; i < zsetnum; i++, j++, remaining--) {
-                    weights[i] = strtod(c->argv[j]->ptr, NULL);
+                    src[i].weight = strtod(c->argv[j]->ptr, NULL);
                 }
             } else {
-                zfree(srcdicts);
-                zfree(weights);
+                zfree(src);
                 addReply(c,shared.syntaxerr);
                 return;
             }
@@ -5402,34 +5410,30 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
     dstzset = dstobj->ptr;
 
     if (op == REDIS_OP_INTER) {
-        /* store index of smallest zset in variable j */
-        for (i = 0, j = 0; i < zsetnum; i++) {
-            if (!srcdicts[i] || dictSize(srcdicts[i]) == 0) {
-                break;
-            }
-            if (dictSize(srcdicts[i]) < dictSize(srcdicts[j])) {
-                j = i;
-            }
-        }
-        /* skip going over all entries if at least one dict was NULL or empty */
-        if (i == zsetnum) {
-            /* precondition: all srcdicts are non-NULL and non-empty */
-            di = dictGetIterator(srcdicts[j]);
+        /* sort sets from the smallest to largest, this will improve our
+         * algorithm's performance */
+        qsort(src,zsetnum,sizeof(zsetopsrc), qsortCompareZsetopsrcByCardinality);
+
+        /* skip going over all entries if the smallest zset is NULL or empty */
+        if (src[0].dict && dictSize(src[0].dict) > 0) {
+            /* precondition: as src[0].dict is non-empty and the zsets are ordered
+             * from small to large, all src[i > 0].dict are non-empty too */
+            di = dictGetIterator(src[0].dict);
             while((de = dictNext(di)) != NULL) {
                 double *score = zmalloc(sizeof(double));
                 *score = 0.0;
 
-                for (k = 0; k < zsetnum; k++) {
-                    dictEntry *other = (k == j) ? de : dictFind(srcdicts[k],dictGetEntryKey(de));
+                for (j = 0; j < zsetnum; j++) {
+                    dictEntry *other = (j == 0) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        *score = *score + weights[k] * (*(double*)dictGetEntryVal(other));
+                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
                     } else {
                         break;
                     }
                 }
 
                 /* skip entry when not present in every source dict */
-                if (k != zsetnum) {
+                if (j != zsetnum) {
                     zfree(score);
                 } else {
                     robj *o = dictGetEntryKey(de);
@@ -5443,9 +5447,9 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
         }
     } else if (op == REDIS_OP_UNION) {
         for (i = 0; i < zsetnum; i++) {
-            if (!srcdicts[i]) continue;
+            if (!src[i].dict) continue;
 
-            di = dictGetIterator(srcdicts[i]);
+            di = dictGetIterator(src[i].dict);
             while((de = dictNext(di)) != NULL) {
                 /* skip key when already processed */
                 if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
@@ -5453,11 +5457,11 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
                 double *score = zmalloc(sizeof(double));
                 *score = 0.0;
                 for (j = 0; j < zsetnum; j++) {
-                    if (!srcdicts[j]) continue;
+                    if (!src[j].dict) continue;
 
-                    dictEntry *other = (i == j) ? de : dictFind(srcdicts[j],dictGetEntryKey(de));
+                    dictEntry *other = (i == j) ? de : dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
-                        *score = *score + weights[j] * (*(double*)dictGetEntryVal(other));
+                        *score = *score + src[j].weight * (*(double*)dictGetEntryVal(other));
                     }
                 }
 
@@ -5480,8 +5484,7 @@ static void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
 
     addReplyLong(c, dstzset->zsl->length);
     server.dirty++;
-    zfree(srcdicts);
-    zfree(weights);
+    zfree(src);
 }
 
 static void zunionCommand(redisClient *c) {