]> git.saurik.com Git - redis.git/blob - src/sort.c
Merge pull request #258 from miaout17/bugfix-247
[redis.git] / src / sort.c
1 #include "redis.h"
2 #include "pqsort.h" /* Partial qsort for SORT+LIMIT */
3
4 redisSortOperation *createSortOperation(int type, robj *pattern) {
5 redisSortOperation *so = zmalloc(sizeof(*so));
6 so->type = type;
7 so->pattern = pattern;
8 return so;
9 }
10
11 /* Return the value associated to the key with a name obtained
12 * substituting the first occurence of '*' in 'pattern' with 'subst'.
13 * The returned object will always have its refcount increased by 1
14 * when it is non-NULL. */
15 robj *lookupKeyByPattern(redisDb *db, robj *pattern, robj *subst) {
16 char *p, *f;
17 sds spat, ssub;
18 robj keyobj, fieldobj, *o;
19 int prefixlen, sublen, postfixlen, fieldlen;
20 /* Expoit the internal sds representation to create a sds string allocated on the stack in order to make this function faster */
21 struct {
22 int len;
23 int free;
24 char buf[REDIS_SORTKEY_MAX+1];
25 } keyname, fieldname;
26
27 /* If the pattern is "#" return the substitution object itself in order
28 * to implement the "SORT ... GET #" feature. */
29 spat = pattern->ptr;
30 if (spat[0] == '#' && spat[1] == '\0') {
31 incrRefCount(subst);
32 return subst;
33 }
34
35 /* The substitution object may be specially encoded. If so we create
36 * a decoded object on the fly. Otherwise getDecodedObject will just
37 * increment the ref count, that we'll decrement later. */
38 subst = getDecodedObject(subst);
39
40 ssub = subst->ptr;
41 if (sdslen(spat)+sdslen(ssub)-1 > REDIS_SORTKEY_MAX) return NULL;
42 p = strchr(spat,'*');
43 if (!p) {
44 decrRefCount(subst);
45 return NULL;
46 }
47
48 /* Find out if we're dealing with a hash dereference. */
49 if ((f = strstr(p+1, "->")) != NULL) {
50 fieldlen = sdslen(spat)-(f-spat);
51 /* this also copies \0 character */
52 memcpy(fieldname.buf,f+2,fieldlen-1);
53 fieldname.len = fieldlen-2;
54 } else {
55 fieldlen = 0;
56 }
57
58 prefixlen = p-spat;
59 sublen = sdslen(ssub);
60 postfixlen = sdslen(spat)-(prefixlen+1)-fieldlen;
61 memcpy(keyname.buf,spat,prefixlen);
62 memcpy(keyname.buf+prefixlen,ssub,sublen);
63 memcpy(keyname.buf+prefixlen+sublen,p+1,postfixlen);
64 keyname.buf[prefixlen+sublen+postfixlen] = '\0';
65 keyname.len = prefixlen+sublen+postfixlen;
66 decrRefCount(subst);
67
68 /* Lookup substituted key */
69 initStaticStringObject(keyobj,((char*)&keyname)+(sizeof(struct sdshdr)));
70 o = lookupKeyRead(db,&keyobj);
71 if (o == NULL) return NULL;
72
73 if (fieldlen > 0) {
74 if (o->type != REDIS_HASH || fieldname.len < 1) return NULL;
75
76 /* Retrieve value from hash by the field name. This operation
77 * already increases the refcount of the returned object. */
78 initStaticStringObject(fieldobj,((char*)&fieldname)+(sizeof(struct sdshdr)));
79 o = hashTypeGetObject(o, &fieldobj);
80 } else {
81 if (o->type != REDIS_STRING) return NULL;
82
83 /* Every object that this function returns needs to have its refcount
84 * increased. sortCommand decreases it again. */
85 incrRefCount(o);
86 }
87
88 return o;
89 }
90
91 /* sortCompare() is used by qsort in sortCommand(). Given that qsort_r with
92 * the additional parameter is not standard but a BSD-specific we have to
93 * pass sorting parameters via the global 'server' structure */
94 int sortCompare(const void *s1, const void *s2) {
95 const redisSortObject *so1 = s1, *so2 = s2;
96 int cmp;
97
98 if (!server.sort_alpha) {
99 /* Numeric sorting. Here it's trivial as we precomputed scores */
100 if (so1->u.score > so2->u.score) {
101 cmp = 1;
102 } else if (so1->u.score < so2->u.score) {
103 cmp = -1;
104 } else {
105 cmp = 0;
106 }
107 } else {
108 /* Alphanumeric sorting */
109 if (server.sort_bypattern) {
110 if (!so1->u.cmpobj || !so2->u.cmpobj) {
111 /* At least one compare object is NULL */
112 if (so1->u.cmpobj == so2->u.cmpobj)
113 cmp = 0;
114 else if (so1->u.cmpobj == NULL)
115 cmp = -1;
116 else
117 cmp = 1;
118 } else {
119 /* We have both the objects, use strcoll */
120 cmp = strcoll(so1->u.cmpobj->ptr,so2->u.cmpobj->ptr);
121 }
122 } else {
123 /* Compare elements directly. */
124 cmp = compareStringObjects(so1->obj,so2->obj);
125 }
126 }
127 return server.sort_desc ? -cmp : cmp;
128 }
129
130 /* The SORT command is the most complex command in Redis. Warning: this code
131 * is optimized for speed and a bit less for readability */
132 void sortCommand(redisClient *c) {
133 list *operations;
134 unsigned int outputlen = 0;
135 int desc = 0, alpha = 0;
136 long limit_start = 0, limit_count = -1, start, end;
137 int j, dontsort = 0, vectorlen;
138 int getop = 0; /* GET operation counter */
139 robj *sortval, *sortby = NULL, *storekey = NULL;
140 redisSortObject *vector; /* Resulting vector to sort */
141
142 /* Lookup the key to sort. It must be of the right types */
143 sortval = lookupKeyRead(c->db,c->argv[1]);
144 if (sortval && sortval->type != REDIS_SET && sortval->type != REDIS_LIST &&
145 sortval->type != REDIS_ZSET)
146 {
147 addReply(c,shared.wrongtypeerr);
148 return;
149 }
150
151 /* Create a list of operations to perform for every sorted element.
152 * Operations can be GET/DEL/INCR/DECR */
153 operations = listCreate();
154 listSetFreeMethod(operations,zfree);
155 j = 2;
156
157 /* Now we need to protect sortval incrementing its count, in the future
158 * SORT may have options able to overwrite/delete keys during the sorting
159 * and the sorted key itself may get destroied */
160 if (sortval)
161 incrRefCount(sortval);
162 else
163 sortval = createListObject();
164
165 /* The SORT command has an SQL-alike syntax, parse it */
166 while(j < c->argc) {
167 int leftargs = c->argc-j-1;
168 if (!strcasecmp(c->argv[j]->ptr,"asc")) {
169 desc = 0;
170 } else if (!strcasecmp(c->argv[j]->ptr,"desc")) {
171 desc = 1;
172 } else if (!strcasecmp(c->argv[j]->ptr,"alpha")) {
173 alpha = 1;
174 } else if (!strcasecmp(c->argv[j]->ptr,"limit") && leftargs >= 2) {
175 if ((getLongFromObjectOrReply(c, c->argv[j+1], &limit_start, NULL) != REDIS_OK) ||
176 (getLongFromObjectOrReply(c, c->argv[j+2], &limit_count, NULL) != REDIS_OK)) return;
177 j+=2;
178 } else if (!strcasecmp(c->argv[j]->ptr,"store") && leftargs >= 1) {
179 storekey = c->argv[j+1];
180 j++;
181 } else if (!strcasecmp(c->argv[j]->ptr,"by") && leftargs >= 1) {
182 sortby = c->argv[j+1];
183 /* If the BY pattern does not contain '*', i.e. it is constant,
184 * we don't need to sort nor to lookup the weight keys. */
185 if (strchr(c->argv[j+1]->ptr,'*') == NULL) dontsort = 1;
186 j++;
187 } else if (!strcasecmp(c->argv[j]->ptr,"get") && leftargs >= 1) {
188 listAddNodeTail(operations,createSortOperation(
189 REDIS_SORT_GET,c->argv[j+1]));
190 getop++;
191 j++;
192 } else {
193 decrRefCount(sortval);
194 listRelease(operations);
195 addReply(c,shared.syntaxerr);
196 return;
197 }
198 j++;
199 }
200
201 /* Destructively convert encoded sorted sets for SORT. */
202 if (sortval->type == REDIS_ZSET)
203 zsetConvert(sortval, REDIS_ENCODING_SKIPLIST);
204
205 /* Load the sorting vector with all the objects to sort */
206 switch(sortval->type) {
207 case REDIS_LIST: vectorlen = listTypeLength(sortval); break;
208 case REDIS_SET: vectorlen = setTypeSize(sortval); break;
209 case REDIS_ZSET: vectorlen = dictSize(((zset*)sortval->ptr)->dict); break;
210 default: vectorlen = 0; redisPanic("Bad SORT type"); /* Avoid GCC warning */
211 }
212 vector = zmalloc(sizeof(redisSortObject)*vectorlen);
213 j = 0;
214
215 if (sortval->type == REDIS_LIST) {
216 listTypeIterator *li = listTypeInitIterator(sortval,0,REDIS_TAIL);
217 listTypeEntry entry;
218 while(listTypeNext(li,&entry)) {
219 vector[j].obj = listTypeGet(&entry);
220 vector[j].u.score = 0;
221 vector[j].u.cmpobj = NULL;
222 j++;
223 }
224 listTypeReleaseIterator(li);
225 } else if (sortval->type == REDIS_SET) {
226 setTypeIterator *si = setTypeInitIterator(sortval);
227 robj *ele;
228 while((ele = setTypeNextObject(si)) != NULL) {
229 vector[j].obj = ele;
230 vector[j].u.score = 0;
231 vector[j].u.cmpobj = NULL;
232 j++;
233 }
234 setTypeReleaseIterator(si);
235 } else if (sortval->type == REDIS_ZSET) {
236 dict *set = ((zset*)sortval->ptr)->dict;
237 dictIterator *di;
238 dictEntry *setele;
239 di = dictGetIterator(set);
240 while((setele = dictNext(di)) != NULL) {
241 vector[j].obj = dictGetKey(setele);
242 vector[j].u.score = 0;
243 vector[j].u.cmpobj = NULL;
244 j++;
245 }
246 dictReleaseIterator(di);
247 } else {
248 redisPanic("Unknown type");
249 }
250 redisAssertWithInfo(c,sortval,j == vectorlen);
251
252 /* Now it's time to load the right scores in the sorting vector */
253 if (dontsort == 0) {
254 for (j = 0; j < vectorlen; j++) {
255 robj *byval;
256 if (sortby) {
257 /* lookup value to sort by */
258 byval = lookupKeyByPattern(c->db,sortby,vector[j].obj);
259 if (!byval) continue;
260 } else {
261 /* use object itself to sort by */
262 byval = vector[j].obj;
263 }
264
265 if (alpha) {
266 if (sortby) vector[j].u.cmpobj = getDecodedObject(byval);
267 } else {
268 if (byval->encoding == REDIS_ENCODING_RAW) {
269 vector[j].u.score = strtod(byval->ptr,NULL);
270 } else if (byval->encoding == REDIS_ENCODING_INT) {
271 /* Don't need to decode the object if it's
272 * integer-encoded (the only encoding supported) so
273 * far. We can just cast it */
274 vector[j].u.score = (long)byval->ptr;
275 } else {
276 redisAssertWithInfo(c,sortval,1 != 1);
277 }
278 }
279
280 /* when the object was retrieved using lookupKeyByPattern,
281 * its refcount needs to be decreased. */
282 if (sortby) {
283 decrRefCount(byval);
284 }
285 }
286 }
287
288 /* We are ready to sort the vector... perform a bit of sanity check
289 * on the LIMIT option too. We'll use a partial version of quicksort. */
290 start = (limit_start < 0) ? 0 : limit_start;
291 end = (limit_count < 0) ? vectorlen-1 : start+limit_count-1;
292 if (start >= vectorlen) {
293 start = vectorlen-1;
294 end = vectorlen-2;
295 }
296 if (end >= vectorlen) end = vectorlen-1;
297
298 if (dontsort == 0) {
299 server.sort_desc = desc;
300 server.sort_alpha = alpha;
301 server.sort_bypattern = sortby ? 1 : 0;
302 if (sortby && (start != 0 || end != vectorlen-1))
303 pqsort(vector,vectorlen,sizeof(redisSortObject),sortCompare, start,end);
304 else
305 qsort(vector,vectorlen,sizeof(redisSortObject),sortCompare);
306 }
307
308 /* Send command output to the output buffer, performing the specified
309 * GET/DEL/INCR/DECR operations if any. */
310 outputlen = getop ? getop*(end-start+1) : end-start+1;
311 if (storekey == NULL) {
312 /* STORE option not specified, sent the sorting result to client */
313 addReplyMultiBulkLen(c,outputlen);
314 for (j = start; j <= end; j++) {
315 listNode *ln;
316 listIter li;
317
318 if (!getop) addReplyBulk(c,vector[j].obj);
319 listRewind(operations,&li);
320 while((ln = listNext(&li))) {
321 redisSortOperation *sop = ln->value;
322 robj *val = lookupKeyByPattern(c->db,sop->pattern,
323 vector[j].obj);
324
325 if (sop->type == REDIS_SORT_GET) {
326 if (!val) {
327 addReply(c,shared.nullbulk);
328 } else {
329 addReplyBulk(c,val);
330 decrRefCount(val);
331 }
332 } else {
333 /* Always fails */
334 redisAssertWithInfo(c,sortval,sop->type == REDIS_SORT_GET);
335 }
336 }
337 }
338 } else {
339 robj *sobj = createZiplistObject();
340
341 /* STORE option specified, set the sorting result as a List object */
342 for (j = start; j <= end; j++) {
343 listNode *ln;
344 listIter li;
345
346 if (!getop) {
347 listTypePush(sobj,vector[j].obj,REDIS_TAIL);
348 } else {
349 listRewind(operations,&li);
350 while((ln = listNext(&li))) {
351 redisSortOperation *sop = ln->value;
352 robj *val = lookupKeyByPattern(c->db,sop->pattern,
353 vector[j].obj);
354
355 if (sop->type == REDIS_SORT_GET) {
356 if (!val) val = createStringObject("",0);
357
358 /* listTypePush does an incrRefCount, so we should take care
359 * care of the incremented refcount caused by either
360 * lookupKeyByPattern or createStringObject("",0) */
361 listTypePush(sobj,val,REDIS_TAIL);
362 decrRefCount(val);
363 } else {
364 /* Always fails */
365 redisAssertWithInfo(c,sortval,sop->type == REDIS_SORT_GET);
366 }
367 }
368 }
369 }
370 if (outputlen) setKey(c->db,storekey,sobj);
371 decrRefCount(sobj);
372 server.dirty += outputlen;
373 addReplyLongLong(c,outputlen);
374 }
375
376 /* Cleanup */
377 if (sortval->type == REDIS_LIST || sortval->type == REDIS_SET)
378 for (j = 0; j < vectorlen; j++)
379 decrRefCount(vector[j].obj);
380 decrRefCount(sortval);
381 listRelease(operations);
382 for (j = 0; j < vectorlen; j++) {
383 if (alpha && vector[j].u.cmpobj)
384 decrRefCount(vector[j].u.cmpobj);
385 }
386 zfree(vector);
387 }
388
389