]> git.saurik.com Git - redis.git/blobdiff - src/t_zset.c
Don't hardcode make to "make"
[redis.git] / src / t_zset.c
index e528205a73ce1ff515d2985e821eca49f25e7983..8139b53d84bb43b94ea7455786eac7ee54403003 100644 (file)
@@ -174,25 +174,35 @@ int zslDelete(zskiplist *zsl, double score, robj *obj) {
     return 0; /* not found */
 }
 
+/* Struct to hold a inclusive/exclusive range spec. */
+typedef struct {
+    double min, max;
+    int minex, maxex; /* are min or max exclusive? */
+} zrangespec;
+
 /* Delete all the elements with score between min and max from the skiplist.
  * Min and mx are inclusive, so a score >= min || score <= max is deleted.
  * Note that this function takes the reference to the hash table view of the
  * sorted set, in order to remove the elements from the hash table too. */
-unsigned long zslDeleteRangeByScore(zskiplist *zsl, double min, double max, dict *dict) {
+unsigned long zslDeleteRangeByScore(zskiplist *zsl, zrangespec range, dict *dict) {
     zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
     unsigned long removed = 0;
     int i;
 
     x = zsl->header;
     for (i = zsl->level-1; i >= 0; i--) {
-        while (x->level[i].forward && x->level[i].forward->score < min)
-            x = x->level[i].forward;
+        while (x->level[i].forward && (range.minex ?
+            x->level[i].forward->score <= range.min :
+            x->level[i].forward->score < range.min))
+                x = x->level[i].forward;
         update[i] = x;
     }
-    /* We may have multiple elements with the same score, what we need
-     * is to find the element with both the right score and object. */
+
+    /* Current node is the last with score < or <= min. */
     x = x->level[0].forward;
-    while (x && x->score <= max) {
+
+    /* Delete nodes while in range. */
+    while (x && (range.maxex ? x->score < range.max : x->score <= range.max)) {
         zskiplistNode *next = x->level[0].forward;
         zslDeleteNode(zsl,x,update);
         dictDelete(dict,x->obj);
@@ -200,7 +210,7 @@ unsigned long zslDeleteRangeByScore(zskiplist *zsl, double min, double max, dict
         removed++;
         x = next;
     }
-    return removed; /* not found */
+    return removed;
 }
 
 /* Delete all the elements with rank between start and end from the skiplist.
@@ -296,13 +306,9 @@ zskiplistNode* zslistTypeGetElementByRank(zskiplist *zsl, unsigned long rank) {
     return NULL;
 }
 
-typedef struct {
-    double min, max;
-    int minex, maxex; /* are min or max exclusive? */
-} zrangespec;
-
 /* Populate the rangespec according to the objects min and max. */
-int zslParseRange(robj *min, robj *max, zrangespec *spec) {
+static int zslParseRange(robj *min, robj *max, zrangespec *spec) {
+    char *eptr;
     spec->minex = spec->maxex = 0;
 
     /* Parse the min-max interval. If one of the values is prefixed
@@ -313,20 +319,24 @@ int zslParseRange(robj *min, robj *max, zrangespec *spec) {
         spec->min = (long)min->ptr;
     } else {
         if (((char*)min->ptr)[0] == '(') {
-            spec->min = strtod((char*)min->ptr+1,NULL);
+            spec->min = strtod((char*)min->ptr+1,&eptr);
+            if (eptr[0] != '\0' || isnan(spec->min)) return REDIS_ERR;
             spec->minex = 1;
         } else {
-            spec->min = strtod((char*)min->ptr,NULL);
+            spec->min = strtod((char*)min->ptr,&eptr);
+            if (eptr[0] != '\0' || isnan(spec->min)) return REDIS_ERR;
         }
     }
     if (max->encoding == REDIS_ENCODING_INT) {
         spec->max = (long)max->ptr;
     } else {
         if (((char*)max->ptr)[0] == '(') {
-            spec->max = strtod((char*)max->ptr+1,NULL);
+            spec->max = strtod((char*)max->ptr+1,&eptr);
+            if (eptr[0] != '\0' || isnan(spec->max)) return REDIS_ERR;
             spec->maxex = 1;
         } else {
-            spec->max = strtod((char*)max->ptr,NULL);
+            spec->max = strtod((char*)max->ptr,&eptr);
+            if (eptr[0] != '\0' || isnan(spec->max)) return REDIS_ERR;
         }
     }
 
@@ -430,12 +440,14 @@ void zaddGenericCommand(redisClient *c, robj *key, robj *ele, double score, int
 void zaddCommand(redisClient *c) {
     double scoreval;
     if (getDoubleFromObjectOrReply(c,c->argv[2],&scoreval,NULL) != REDIS_OK) return;
+    c->argv[3] = tryObjectEncoding(c->argv[3]);
     zaddGenericCommand(c,c->argv[1],c->argv[3],scoreval,0);
 }
 
 void zincrbyCommand(redisClient *c) {
     double scoreval;
     if (getDoubleFromObjectOrReply(c,c->argv[2],&scoreval,NULL) != REDIS_OK) return;
+    c->argv[3] = tryObjectEncoding(c->argv[3]);
     zaddGenericCommand(c,c->argv[1],c->argv[3],scoreval,1);
 }
 
@@ -450,6 +462,7 @@ void zremCommand(redisClient *c) {
         checkType(c,zsetobj,REDIS_ZSET)) return;
 
     zs = zsetobj->ptr;
+    c->argv[2] = tryObjectEncoding(c->argv[2]);
     de = dictFind(zs->dict,c->argv[2]);
     if (de == NULL) {
         addReply(c,shared.czero);
@@ -470,20 +483,22 @@ void zremCommand(redisClient *c) {
 }
 
 void zremrangebyscoreCommand(redisClient *c) {
-    double min;
-    double max;
+    zrangespec range;
     long deleted;
-    robj *zsetobj;
+    robj *o;
     zset *zs;
 
-    if ((getDoubleFromObjectOrReply(c, c->argv[2], &min, NULL) != REDIS_OK) ||
-        (getDoubleFromObjectOrReply(c, c->argv[3], &max, NULL) != REDIS_OK)) return;
+    /* Parse the range arguments. */
+    if (zslParseRange(c->argv[2],c->argv[3],&range) != REDIS_OK) {
+        addReplyError(c,"min or max is not a double");
+        return;
+    }
 
-    if ((zsetobj = lookupKeyWriteOrReply(c,c->argv[1],shared.czero)) == NULL ||
-        checkType(c,zsetobj,REDIS_ZSET)) return;
+    if ((o = lookupKeyWriteOrReply(c,c->argv[1],shared.czero)) == NULL ||
+        checkType(c,o,REDIS_ZSET)) return;
 
-    zs = zsetobj->ptr;
-    deleted = zslDeleteRangeByScore(zs->zsl,min,max,zs->dict);
+    zs = o->ptr;
+    deleted = zslDeleteRangeByScore(zs->zsl,range,zs->dict);
     if (htNeedsResize(zs->dict)) dictResize(zs->dict);
     if (dictSize(zs->dict) == 0) dbDelete(c->db,c->argv[1]);
     if (deleted) touchWatchedKey(c->db,c->argv[1]);
@@ -663,25 +678,23 @@ void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
              * from small to large, all src[i > 0].dict are non-empty too */
             di = dictGetIterator(src[0].dict);
             while((de = dictNext(di)) != NULL) {
-                double *score = zmalloc(sizeof(double)), value;
-                *score = src[0].weight * zunionInterDictValue(de);
+                double score, value;
 
+                score = src[0].weight * zunionInterDictValue(de);
                 for (j = 1; j < setnum; j++) {
                     dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
                         value = src[j].weight * zunionInterDictValue(other);
-                        zunionInterAggregate(score, value, aggregate);
+                        zunionInterAggregate(&score,value,aggregate);
                     } else {
                         break;
                     }
                 }
 
-                /* skip entry when not present in every source dict */
-                if (j != setnum) {
-                    zfree(score);
-                } else {
+                /* Only continue when present in every source dict. */
+                if (j == setnum) {
                     robj *o = dictGetEntryKey(de);
-                    znode = zslInsert(dstzset->zsl,*score,o);
+                    znode = zslInsert(dstzset->zsl,score,o);
                     incrRefCount(o); /* added to skiplist */
                     dictAdd(dstzset->dict,o,&znode->score);
                     incrRefCount(o); /* added to dictionary */
@@ -695,11 +708,14 @@ void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
 
             di = dictGetIterator(src[i].dict);
             while((de = dictNext(di)) != NULL) {
+                double score, value;
+
                 /* skip key when already processed */
-                if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL) continue;
+                if (dictFind(dstzset->dict,dictGetEntryKey(de)) != NULL)
+                    continue;
 
-                double *score = zmalloc(sizeof(double)), value;
-                *score = src[i].weight * zunionInterDictValue(de);
+                /* initialize score */
+                score = src[i].weight * zunionInterDictValue(de);
 
                 /* because the zsets are sorted by size, its only possible
                  * for sets at larger indices to hold this entry */
@@ -707,12 +723,12 @@ void zunionInterGenericCommand(redisClient *c, robj *dstkey, int op) {
                     dictEntry *other = dictFind(src[j].dict,dictGetEntryKey(de));
                     if (other) {
                         value = src[j].weight * zunionInterDictValue(other);
-                        zunionInterAggregate(score, value, aggregate);
+                        zunionInterAggregate(&score,value,aggregate);
                     }
                 }
 
                 robj *o = dictGetEntryKey(de);
-                znode = zslInsert(dstzset->zsl,*score,o);
+                znode = zslInsert(dstzset->zsl,score,o);
                 incrRefCount(o); /* added to skiplist */
                 dictAdd(dstzset->dict,o,&znode->score);
                 incrRefCount(o); /* added to dictionary */
@@ -833,7 +849,10 @@ void genericZrangebyscoreCommand(redisClient *c, int reverse, int justcount) {
     void *replylen = NULL;
 
     /* Parse the range arguments. */
-    zslParseRange(c->argv[2],c->argv[3],&range);
+    if (zslParseRange(c->argv[2],c->argv[3],&range) != REDIS_OK) {
+        addReplyError(c,"min or max is not a double");
+        return;
+    }
 
     /* Parse optional extra arguments. Note that ZCOUNT will exactly have
      * 4 arguments, so we'll never enter the following code path. */
@@ -988,6 +1007,7 @@ void zscoreCommand(redisClient *c) {
         checkType(c,o,REDIS_ZSET)) return;
 
     zs = o->ptr;
+    c->argv[2] = tryObjectEncoding(c->argv[2]);
     de = dictFind(zs->dict,c->argv[2]);
     if (!de) {
         addReply(c,shared.nullbulk);
@@ -1011,6 +1031,7 @@ void zrankGenericCommand(redisClient *c, int reverse) {
 
     zs = o->ptr;
     zsl = zs->zsl;
+    c->argv[2] = tryObjectEncoding(c->argv[2]);
     de = dictFind(zs->dict,c->argv[2]);
     if (!de) {
         addReply(c,shared.nullbulk);