]> git.saurik.com Git - redis.git/blob - src/sort.c
fix rare condition where 'key' would already be destroyed while it was needed later on
[redis.git] / src / sort.c
1 #include "redis.h"
2 #include "pqsort.h" /* Partial qsort for SORT+LIMIT */
3
4 redisSortOperation *createSortOperation(int type, robj *pattern) {
5 redisSortOperation *so = zmalloc(sizeof(*so));
6 so->type = type;
7 so->pattern = pattern;
8 return so;
9 }
10
11 /* Return the value associated to the key with a name obtained
12 * substituting the first occurence of '*' in 'pattern' with 'subst'.
13 * The returned object will always have its refcount increased by 1
14 * when it is non-NULL. */
15 robj *lookupKeyByPattern(redisDb *db, robj *pattern, robj *subst) {
16 char *p, *f;
17 sds spat, ssub;
18 robj keyobj, fieldobj, *o;
19 int prefixlen, sublen, postfixlen, fieldlen;
20 /* Expoit the internal sds representation to create a sds string allocated on the stack in order to make this function faster */
21 struct {
22 int len;
23 int free;
24 char buf[REDIS_SORTKEY_MAX+1];
25 } keyname, fieldname;
26
27 /* If the pattern is "#" return the substitution object itself in order
28 * to implement the "SORT ... GET #" feature. */
29 spat = pattern->ptr;
30 if (spat[0] == '#' && spat[1] == '\0') {
31 incrRefCount(subst);
32 return subst;
33 }
34
35 /* The substitution object may be specially encoded. If so we create
36 * a decoded object on the fly. Otherwise getDecodedObject will just
37 * increment the ref count, that we'll decrement later. */
38 subst = getDecodedObject(subst);
39
40 ssub = subst->ptr;
41 if (sdslen(spat)+sdslen(ssub)-1 > REDIS_SORTKEY_MAX) return NULL;
42 p = strchr(spat,'*');
43 if (!p) {
44 decrRefCount(subst);
45 return NULL;
46 }
47
48 /* Find out if we're dealing with a hash dereference. */
49 if ((f = strstr(p+1, "->")) != NULL) {
50 fieldlen = sdslen(spat)-(f-spat);
51 /* this also copies \0 character */
52 memcpy(fieldname.buf,f+2,fieldlen-1);
53 fieldname.len = fieldlen-2;
54 } else {
55 fieldlen = 0;
56 }
57
58 prefixlen = p-spat;
59 sublen = sdslen(ssub);
60 postfixlen = sdslen(spat)-(prefixlen+1)-fieldlen;
61 memcpy(keyname.buf,spat,prefixlen);
62 memcpy(keyname.buf+prefixlen,ssub,sublen);
63 memcpy(keyname.buf+prefixlen+sublen,p+1,postfixlen);
64 keyname.buf[prefixlen+sublen+postfixlen] = '\0';
65 keyname.len = prefixlen+sublen+postfixlen;
66 decrRefCount(subst);
67
68 /* Lookup substituted key */
69 initStaticStringObject(keyobj,((char*)&keyname)+(sizeof(struct sdshdr)));
70 o = lookupKeyRead(db,&keyobj);
71 if (o == NULL) return NULL;
72
73 if (fieldlen > 0) {
74 if (o->type != REDIS_HASH || fieldname.len < 1) return NULL;
75
76 /* Retrieve value from hash by the field name. This operation
77 * already increases the refcount of the returned object. */
78 initStaticStringObject(fieldobj,((char*)&fieldname)+(sizeof(struct sdshdr)));
79 o = hashTypeGet(o, &fieldobj);
80 } else {
81 if (o->type != REDIS_STRING) return NULL;
82
83 /* Every object that this function returns needs to have its refcount
84 * increased. sortCommand decreases it again. */
85 incrRefCount(o);
86 }
87
88 return o;
89 }
90
91 /* sortCompare() is used by qsort in sortCommand(). Given that qsort_r with
92 * the additional parameter is not standard but a BSD-specific we have to
93 * pass sorting parameters via the global 'server' structure */
94 int sortCompare(const void *s1, const void *s2) {
95 const redisSortObject *so1 = s1, *so2 = s2;
96 int cmp;
97
98 if (!server.sort_alpha) {
99 /* Numeric sorting. Here it's trivial as we precomputed scores */
100 if (so1->u.score > so2->u.score) {
101 cmp = 1;
102 } else if (so1->u.score < so2->u.score) {
103 cmp = -1;
104 } else {
105 cmp = 0;
106 }
107 } else {
108 /* Alphanumeric sorting */
109 if (server.sort_bypattern) {
110 if (!so1->u.cmpobj || !so2->u.cmpobj) {
111 /* At least one compare object is NULL */
112 if (so1->u.cmpobj == so2->u.cmpobj)
113 cmp = 0;
114 else if (so1->u.cmpobj == NULL)
115 cmp = -1;
116 else
117 cmp = 1;
118 } else {
119 /* We have both the objects, use strcoll */
120 cmp = strcoll(so1->u.cmpobj->ptr,so2->u.cmpobj->ptr);
121 }
122 } else {
123 /* Compare elements directly. */
124 cmp = compareStringObjects(so1->obj,so2->obj);
125 }
126 }
127 return server.sort_desc ? -cmp : cmp;
128 }
129
130 /* The SORT command is the most complex command in Redis. Warning: this code
131 * is optimized for speed and a bit less for readability */
132 void sortCommand(redisClient *c) {
133 list *operations;
134 unsigned int outputlen = 0;
135 int desc = 0, alpha = 0;
136 int limit_start = 0, limit_count = -1, start, end;
137 int j, dontsort = 0, vectorlen;
138 int getop = 0; /* GET operation counter */
139 robj *sortval, *sortby = NULL, *storekey = NULL;
140 redisSortObject *vector; /* Resulting vector to sort */
141
142 /* Lookup the key to sort. It must be of the right types */
143 sortval = lookupKeyRead(c->db,c->argv[1]);
144 if (sortval == NULL) {
145 addReply(c,shared.emptymultibulk);
146 return;
147 }
148 if (sortval->type != REDIS_SET && sortval->type != REDIS_LIST &&
149 sortval->type != REDIS_ZSET)
150 {
151 addReply(c,shared.wrongtypeerr);
152 return;
153 }
154
155 /* Create a list of operations to perform for every sorted element.
156 * Operations can be GET/DEL/INCR/DECR */
157 operations = listCreate();
158 listSetFreeMethod(operations,zfree);
159 j = 2;
160
161 /* Now we need to protect sortval incrementing its count, in the future
162 * SORT may have options able to overwrite/delete keys during the sorting
163 * and the sorted key itself may get destroied */
164 incrRefCount(sortval);
165
166 /* The SORT command has an SQL-alike syntax, parse it */
167 while(j < c->argc) {
168 int leftargs = c->argc-j-1;
169 if (!strcasecmp(c->argv[j]->ptr,"asc")) {
170 desc = 0;
171 } else if (!strcasecmp(c->argv[j]->ptr,"desc")) {
172 desc = 1;
173 } else if (!strcasecmp(c->argv[j]->ptr,"alpha")) {
174 alpha = 1;
175 } else if (!strcasecmp(c->argv[j]->ptr,"limit") && leftargs >= 2) {
176 limit_start = atoi(c->argv[j+1]->ptr);
177 limit_count = atoi(c->argv[j+2]->ptr);
178 j+=2;
179 } else if (!strcasecmp(c->argv[j]->ptr,"store") && leftargs >= 1) {
180 storekey = c->argv[j+1];
181 j++;
182 } else if (!strcasecmp(c->argv[j]->ptr,"by") && leftargs >= 1) {
183 sortby = c->argv[j+1];
184 /* If the BY pattern does not contain '*', i.e. it is constant,
185 * we don't need to sort nor to lookup the weight keys. */
186 if (strchr(c->argv[j+1]->ptr,'*') == NULL) dontsort = 1;
187 j++;
188 } else if (!strcasecmp(c->argv[j]->ptr,"get") && leftargs >= 1) {
189 listAddNodeTail(operations,createSortOperation(
190 REDIS_SORT_GET,c->argv[j+1]));
191 getop++;
192 j++;
193 } else {
194 decrRefCount(sortval);
195 listRelease(operations);
196 addReply(c,shared.syntaxerr);
197 return;
198 }
199 j++;
200 }
201
202 /* Load the sorting vector with all the objects to sort */
203 switch(sortval->type) {
204 case REDIS_LIST: vectorlen = listTypeLength(sortval); break;
205 case REDIS_SET: vectorlen = dictSize((dict*)sortval->ptr); break;
206 case REDIS_ZSET: vectorlen = dictSize(((zset*)sortval->ptr)->dict); break;
207 default: vectorlen = 0; redisPanic("Bad SORT type"); /* Avoid GCC warning */
208 }
209 vector = zmalloc(sizeof(redisSortObject)*vectorlen);
210 j = 0;
211
212 if (sortval->type == REDIS_LIST) {
213 listTypeIterator *li = listTypeInitIterator(sortval,0,REDIS_TAIL);
214 listTypeEntry entry;
215 while(listTypeNext(li,&entry)) {
216 vector[j].obj = listTypeGet(&entry);
217 vector[j].u.score = 0;
218 vector[j].u.cmpobj = NULL;
219 j++;
220 }
221 listTypeReleaseIterator(li);
222 } else {
223 dict *set;
224 dictIterator *di;
225 dictEntry *setele;
226
227 if (sortval->type == REDIS_SET) {
228 set = sortval->ptr;
229 } else {
230 zset *zs = sortval->ptr;
231 set = zs->dict;
232 }
233
234 di = dictGetIterator(set);
235 while((setele = dictNext(di)) != NULL) {
236 vector[j].obj = dictGetEntryKey(setele);
237 vector[j].u.score = 0;
238 vector[j].u.cmpobj = NULL;
239 j++;
240 }
241 dictReleaseIterator(di);
242 }
243 redisAssert(j == vectorlen);
244
245 /* Now it's time to load the right scores in the sorting vector */
246 if (dontsort == 0) {
247 for (j = 0; j < vectorlen; j++) {
248 robj *byval;
249 if (sortby) {
250 /* lookup value to sort by */
251 byval = lookupKeyByPattern(c->db,sortby,vector[j].obj);
252 if (!byval) continue;
253 } else {
254 /* use object itself to sort by */
255 byval = vector[j].obj;
256 }
257
258 if (alpha) {
259 if (sortby) vector[j].u.cmpobj = getDecodedObject(byval);
260 } else {
261 if (byval->encoding == REDIS_ENCODING_RAW) {
262 vector[j].u.score = strtod(byval->ptr,NULL);
263 } else if (byval->encoding == REDIS_ENCODING_INT) {
264 /* Don't need to decode the object if it's
265 * integer-encoded (the only encoding supported) so
266 * far. We can just cast it */
267 vector[j].u.score = (long)byval->ptr;
268 } else {
269 redisAssert(1 != 1);
270 }
271 }
272
273 /* when the object was retrieved using lookupKeyByPattern,
274 * its refcount needs to be decreased. */
275 if (sortby) {
276 decrRefCount(byval);
277 }
278 }
279 }
280
281 /* We are ready to sort the vector... perform a bit of sanity check
282 * on the LIMIT option too. We'll use a partial version of quicksort. */
283 start = (limit_start < 0) ? 0 : limit_start;
284 end = (limit_count < 0) ? vectorlen-1 : start+limit_count-1;
285 if (start >= vectorlen) {
286 start = vectorlen-1;
287 end = vectorlen-2;
288 }
289 if (end >= vectorlen) end = vectorlen-1;
290
291 if (dontsort == 0) {
292 server.sort_desc = desc;
293 server.sort_alpha = alpha;
294 server.sort_bypattern = sortby ? 1 : 0;
295 if (sortby && (start != 0 || end != vectorlen-1))
296 pqsort(vector,vectorlen,sizeof(redisSortObject),sortCompare, start,end);
297 else
298 qsort(vector,vectorlen,sizeof(redisSortObject),sortCompare);
299 }
300
301 /* Send command output to the output buffer, performing the specified
302 * GET/DEL/INCR/DECR operations if any. */
303 outputlen = getop ? getop*(end-start+1) : end-start+1;
304 if (storekey == NULL) {
305 /* STORE option not specified, sent the sorting result to client */
306 addReplySds(c,sdscatprintf(sdsempty(),"*%d\r\n",outputlen));
307 for (j = start; j <= end; j++) {
308 listNode *ln;
309 listIter li;
310
311 if (!getop) addReplyBulk(c,vector[j].obj);
312 listRewind(operations,&li);
313 while((ln = listNext(&li))) {
314 redisSortOperation *sop = ln->value;
315 robj *val = lookupKeyByPattern(c->db,sop->pattern,
316 vector[j].obj);
317
318 if (sop->type == REDIS_SORT_GET) {
319 if (!val) {
320 addReply(c,shared.nullbulk);
321 } else {
322 addReplyBulk(c,val);
323 decrRefCount(val);
324 }
325 } else {
326 redisAssert(sop->type == REDIS_SORT_GET); /* always fails */
327 }
328 }
329 }
330 } else {
331 robj *sobj = createZiplistObject();
332
333 /* STORE option specified, set the sorting result as a List object */
334 for (j = start; j <= end; j++) {
335 listNode *ln;
336 listIter li;
337
338 if (!getop) {
339 listTypePush(sobj,vector[j].obj,REDIS_TAIL);
340 } else {
341 listRewind(operations,&li);
342 while((ln = listNext(&li))) {
343 redisSortOperation *sop = ln->value;
344 robj *val = lookupKeyByPattern(c->db,sop->pattern,
345 vector[j].obj);
346
347 if (sop->type == REDIS_SORT_GET) {
348 if (!val) val = createStringObject("",0);
349
350 /* listTypePush does an incrRefCount, so we should take care
351 * care of the incremented refcount caused by either
352 * lookupKeyByPattern or createStringObject("",0) */
353 listTypePush(sobj,val,REDIS_TAIL);
354 decrRefCount(val);
355 } else {
356 /* always fails */
357 redisAssert(sop->type == REDIS_SORT_GET);
358 }
359 }
360 }
361 }
362 dbReplace(c->db,storekey,sobj);
363 /* Note: we add 1 because the DB is dirty anyway since even if the
364 * SORT result is empty a new key is set and maybe the old content
365 * replaced. */
366 server.dirty += 1+outputlen;
367 touchWatchedKey(c->db,storekey);
368 addReplySds(c,sdscatprintf(sdsempty(),":%d\r\n",outputlen));
369 }
370
371 /* Cleanup */
372 if (sortval->type == REDIS_LIST)
373 for (j = 0; j < vectorlen; j++)
374 decrRefCount(vector[j].obj);
375 decrRefCount(sortval);
376 listRelease(operations);
377 for (j = 0; j < vectorlen; j++) {
378 if (alpha && vector[j].u.cmpobj)
379 decrRefCount(vector[j].u.cmpobj);
380 }
381 zfree(vector);
382 }
383
384