]> git.saurik.com Git - redis.git/blob - src/sort.c
Finished code for sorted set memory efficiency
[redis.git] / src / sort.c
1 #include "redis.h"
2 #include "pqsort.h" /* Partial qsort for SORT+LIMIT */
3
4 redisSortOperation *createSortOperation(int type, robj *pattern) {
5 redisSortOperation *so = zmalloc(sizeof(*so));
6 so->type = type;
7 so->pattern = pattern;
8 return so;
9 }
10
11 /* Return the value associated to the key with a name obtained
12 * substituting the first occurence of '*' in 'pattern' with 'subst'.
13 * The returned object will always have its refcount increased by 1
14 * when it is non-NULL. */
15 robj *lookupKeyByPattern(redisDb *db, robj *pattern, robj *subst) {
16 char *p, *f;
17 sds spat, ssub;
18 robj keyobj, fieldobj, *o;
19 int prefixlen, sublen, postfixlen, fieldlen;
20 /* Expoit the internal sds representation to create a sds string allocated on the stack in order to make this function faster */
21 struct {
22 int len;
23 int free;
24 char buf[REDIS_SORTKEY_MAX+1];
25 } keyname, fieldname;
26
27 /* If the pattern is "#" return the substitution object itself in order
28 * to implement the "SORT ... GET #" feature. */
29 spat = pattern->ptr;
30 if (spat[0] == '#' && spat[1] == '\0') {
31 incrRefCount(subst);
32 return subst;
33 }
34
35 /* The substitution object may be specially encoded. If so we create
36 * a decoded object on the fly. Otherwise getDecodedObject will just
37 * increment the ref count, that we'll decrement later. */
38 subst = getDecodedObject(subst);
39
40 ssub = subst->ptr;
41 if (sdslen(spat)+sdslen(ssub)-1 > REDIS_SORTKEY_MAX) return NULL;
42 p = strchr(spat,'*');
43 if (!p) {
44 decrRefCount(subst);
45 return NULL;
46 }
47
48 /* Find out if we're dealing with a hash dereference. */
49 if ((f = strstr(p+1, "->")) != NULL) {
50 fieldlen = sdslen(spat)-(f-spat);
51 /* this also copies \0 character */
52 memcpy(fieldname.buf,f+2,fieldlen-1);
53 fieldname.len = fieldlen-2;
54 } else {
55 fieldlen = 0;
56 }
57
58 prefixlen = p-spat;
59 sublen = sdslen(ssub);
60 postfixlen = sdslen(spat)-(prefixlen+1)-fieldlen;
61 memcpy(keyname.buf,spat,prefixlen);
62 memcpy(keyname.buf+prefixlen,ssub,sublen);
63 memcpy(keyname.buf+prefixlen+sublen,p+1,postfixlen);
64 keyname.buf[prefixlen+sublen+postfixlen] = '\0';
65 keyname.len = prefixlen+sublen+postfixlen;
66 decrRefCount(subst);
67
68 /* Lookup substituted key */
69 initStaticStringObject(keyobj,((char*)&keyname)+(sizeof(struct sdshdr)));
70 o = lookupKeyRead(db,&keyobj);
71 if (o == NULL) return NULL;
72
73 if (fieldlen > 0) {
74 if (o->type != REDIS_HASH || fieldname.len < 1) return NULL;
75
76 /* Retrieve value from hash by the field name. This operation
77 * already increases the refcount of the returned object. */
78 initStaticStringObject(fieldobj,((char*)&fieldname)+(sizeof(struct sdshdr)));
79 o = hashTypeGet(o, &fieldobj);
80 } else {
81 if (o->type != REDIS_STRING) return NULL;
82
83 /* Every object that this function returns needs to have its refcount
84 * increased. sortCommand decreases it again. */
85 incrRefCount(o);
86 }
87
88 return o;
89 }
90
91 /* sortCompare() is used by qsort in sortCommand(). Given that qsort_r with
92 * the additional parameter is not standard but a BSD-specific we have to
93 * pass sorting parameters via the global 'server' structure */
94 int sortCompare(const void *s1, const void *s2) {
95 const redisSortObject *so1 = s1, *so2 = s2;
96 int cmp;
97
98 if (!server.sort_alpha) {
99 /* Numeric sorting. Here it's trivial as we precomputed scores */
100 if (so1->u.score > so2->u.score) {
101 cmp = 1;
102 } else if (so1->u.score < so2->u.score) {
103 cmp = -1;
104 } else {
105 cmp = 0;
106 }
107 } else {
108 /* Alphanumeric sorting */
109 if (server.sort_bypattern) {
110 if (!so1->u.cmpobj || !so2->u.cmpobj) {
111 /* At least one compare object is NULL */
112 if (so1->u.cmpobj == so2->u.cmpobj)
113 cmp = 0;
114 else if (so1->u.cmpobj == NULL)
115 cmp = -1;
116 else
117 cmp = 1;
118 } else {
119 /* We have both the objects, use strcoll */
120 cmp = strcoll(so1->u.cmpobj->ptr,so2->u.cmpobj->ptr);
121 }
122 } else {
123 /* Compare elements directly. */
124 cmp = compareStringObjects(so1->obj,so2->obj);
125 }
126 }
127 return server.sort_desc ? -cmp : cmp;
128 }
129
130 /* The SORT command is the most complex command in Redis. Warning: this code
131 * is optimized for speed and a bit less for readability */
132 void sortCommand(redisClient *c) {
133 list *operations;
134 unsigned int outputlen = 0;
135 int desc = 0, alpha = 0;
136 int limit_start = 0, limit_count = -1, start, end;
137 int j, dontsort = 0, vectorlen;
138 int getop = 0; /* GET operation counter */
139 robj *sortval, *sortby = NULL, *storekey = NULL;
140 redisSortObject *vector; /* Resulting vector to sort */
141
142 /* Lookup the key to sort. It must be of the right types */
143 sortval = lookupKeyRead(c->db,c->argv[1]);
144 if (sortval == NULL) {
145 addReply(c,shared.emptymultibulk);
146 return;
147 }
148 if (sortval->type != REDIS_SET && sortval->type != REDIS_LIST &&
149 sortval->type != REDIS_ZSET)
150 {
151 addReply(c,shared.wrongtypeerr);
152 return;
153 }
154
155 /* Create a list of operations to perform for every sorted element.
156 * Operations can be GET/DEL/INCR/DECR */
157 operations = listCreate();
158 listSetFreeMethod(operations,zfree);
159 j = 2;
160
161 /* Now we need to protect sortval incrementing its count, in the future
162 * SORT may have options able to overwrite/delete keys during the sorting
163 * and the sorted key itself may get destroied */
164 incrRefCount(sortval);
165
166 /* The SORT command has an SQL-alike syntax, parse it */
167 while(j < c->argc) {
168 int leftargs = c->argc-j-1;
169 if (!strcasecmp(c->argv[j]->ptr,"asc")) {
170 desc = 0;
171 } else if (!strcasecmp(c->argv[j]->ptr,"desc")) {
172 desc = 1;
173 } else if (!strcasecmp(c->argv[j]->ptr,"alpha")) {
174 alpha = 1;
175 } else if (!strcasecmp(c->argv[j]->ptr,"limit") && leftargs >= 2) {
176 limit_start = atoi(c->argv[j+1]->ptr);
177 limit_count = atoi(c->argv[j+2]->ptr);
178 j+=2;
179 } else if (!strcasecmp(c->argv[j]->ptr,"store") && leftargs >= 1) {
180 storekey = c->argv[j+1];
181 j++;
182 } else if (!strcasecmp(c->argv[j]->ptr,"by") && leftargs >= 1) {
183 sortby = c->argv[j+1];
184 /* If the BY pattern does not contain '*', i.e. it is constant,
185 * we don't need to sort nor to lookup the weight keys. */
186 if (strchr(c->argv[j+1]->ptr,'*') == NULL) dontsort = 1;
187 j++;
188 } else if (!strcasecmp(c->argv[j]->ptr,"get") && leftargs >= 1) {
189 listAddNodeTail(operations,createSortOperation(
190 REDIS_SORT_GET,c->argv[j+1]));
191 getop++;
192 j++;
193 } else {
194 decrRefCount(sortval);
195 listRelease(operations);
196 addReply(c,shared.syntaxerr);
197 return;
198 }
199 j++;
200 }
201
202 /* Load the sorting vector with all the objects to sort */
203 switch(sortval->type) {
204 case REDIS_LIST: vectorlen = listTypeLength(sortval); break;
205 case REDIS_SET: vectorlen = setTypeSize(sortval); break;
206 case REDIS_ZSET: vectorlen = dictSize(((zset*)sortval->ptr)->dict); break;
207 default: vectorlen = 0; redisPanic("Bad SORT type"); /* Avoid GCC warning */
208 }
209 vector = zmalloc(sizeof(redisSortObject)*vectorlen);
210 j = 0;
211
212 if (sortval->type == REDIS_LIST) {
213 listTypeIterator *li = listTypeInitIterator(sortval,0,REDIS_TAIL);
214 listTypeEntry entry;
215 while(listTypeNext(li,&entry)) {
216 vector[j].obj = listTypeGet(&entry);
217 vector[j].u.score = 0;
218 vector[j].u.cmpobj = NULL;
219 j++;
220 }
221 listTypeReleaseIterator(li);
222 } else if (sortval->type == REDIS_SET) {
223 setTypeIterator *si = setTypeInitIterator(sortval);
224 robj *ele;
225 while((ele = setTypeNext(si)) != NULL) {
226 vector[j].obj = ele;
227 vector[j].u.score = 0;
228 vector[j].u.cmpobj = NULL;
229 j++;
230 }
231 setTypeReleaseIterator(si);
232 } else if (sortval->type == REDIS_ZSET) {
233 dict *set = ((zset*)sortval->ptr)->dict;
234 dictIterator *di;
235 dictEntry *setele;
236 di = dictGetIterator(set);
237 while((setele = dictNext(di)) != NULL) {
238 vector[j].obj = dictGetEntryKey(setele);
239 vector[j].u.score = 0;
240 vector[j].u.cmpobj = NULL;
241 j++;
242 }
243 dictReleaseIterator(di);
244 } else {
245 redisPanic("Unknown type");
246 }
247 redisAssert(j == vectorlen);
248
249 /* Now it's time to load the right scores in the sorting vector */
250 if (dontsort == 0) {
251 for (j = 0; j < vectorlen; j++) {
252 robj *byval;
253 if (sortby) {
254 /* lookup value to sort by */
255 byval = lookupKeyByPattern(c->db,sortby,vector[j].obj);
256 if (!byval) continue;
257 } else {
258 /* use object itself to sort by */
259 byval = vector[j].obj;
260 }
261
262 if (alpha) {
263 if (sortby) vector[j].u.cmpobj = getDecodedObject(byval);
264 } else {
265 if (byval->encoding == REDIS_ENCODING_RAW) {
266 vector[j].u.score = strtod(byval->ptr,NULL);
267 } else if (byval->encoding == REDIS_ENCODING_INT) {
268 /* Don't need to decode the object if it's
269 * integer-encoded (the only encoding supported) so
270 * far. We can just cast it */
271 vector[j].u.score = (long)byval->ptr;
272 } else {
273 redisAssert(1 != 1);
274 }
275 }
276
277 /* when the object was retrieved using lookupKeyByPattern,
278 * its refcount needs to be decreased. */
279 if (sortby) {
280 decrRefCount(byval);
281 }
282 }
283 }
284
285 /* We are ready to sort the vector... perform a bit of sanity check
286 * on the LIMIT option too. We'll use a partial version of quicksort. */
287 start = (limit_start < 0) ? 0 : limit_start;
288 end = (limit_count < 0) ? vectorlen-1 : start+limit_count-1;
289 if (start >= vectorlen) {
290 start = vectorlen-1;
291 end = vectorlen-2;
292 }
293 if (end >= vectorlen) end = vectorlen-1;
294
295 if (dontsort == 0) {
296 server.sort_desc = desc;
297 server.sort_alpha = alpha;
298 server.sort_bypattern = sortby ? 1 : 0;
299 if (sortby && (start != 0 || end != vectorlen-1))
300 pqsort(vector,vectorlen,sizeof(redisSortObject),sortCompare, start,end);
301 else
302 qsort(vector,vectorlen,sizeof(redisSortObject),sortCompare);
303 }
304
305 /* Send command output to the output buffer, performing the specified
306 * GET/DEL/INCR/DECR operations if any. */
307 outputlen = getop ? getop*(end-start+1) : end-start+1;
308 if (storekey == NULL) {
309 /* STORE option not specified, sent the sorting result to client */
310 addReplyMultiBulkLen(c,outputlen);
311 for (j = start; j <= end; j++) {
312 listNode *ln;
313 listIter li;
314
315 if (!getop) addReplyBulk(c,vector[j].obj);
316 listRewind(operations,&li);
317 while((ln = listNext(&li))) {
318 redisSortOperation *sop = ln->value;
319 robj *val = lookupKeyByPattern(c->db,sop->pattern,
320 vector[j].obj);
321
322 if (sop->type == REDIS_SORT_GET) {
323 if (!val) {
324 addReply(c,shared.nullbulk);
325 } else {
326 addReplyBulk(c,val);
327 decrRefCount(val);
328 }
329 } else {
330 redisAssert(sop->type == REDIS_SORT_GET); /* always fails */
331 }
332 }
333 }
334 } else {
335 robj *sobj = createZiplistObject();
336
337 /* STORE option specified, set the sorting result as a List object */
338 for (j = start; j <= end; j++) {
339 listNode *ln;
340 listIter li;
341
342 if (!getop) {
343 listTypePush(sobj,vector[j].obj,REDIS_TAIL);
344 } else {
345 listRewind(operations,&li);
346 while((ln = listNext(&li))) {
347 redisSortOperation *sop = ln->value;
348 robj *val = lookupKeyByPattern(c->db,sop->pattern,
349 vector[j].obj);
350
351 if (sop->type == REDIS_SORT_GET) {
352 if (!val) val = createStringObject("",0);
353
354 /* listTypePush does an incrRefCount, so we should take care
355 * care of the incremented refcount caused by either
356 * lookupKeyByPattern or createStringObject("",0) */
357 listTypePush(sobj,val,REDIS_TAIL);
358 decrRefCount(val);
359 } else {
360 /* always fails */
361 redisAssert(sop->type == REDIS_SORT_GET);
362 }
363 }
364 }
365 }
366 dbReplace(c->db,storekey,sobj);
367 /* Note: we add 1 because the DB is dirty anyway since even if the
368 * SORT result is empty a new key is set and maybe the old content
369 * replaced. */
370 server.dirty += 1+outputlen;
371 touchWatchedKey(c->db,storekey);
372 addReplyLongLong(c,outputlen);
373 }
374
375 /* Cleanup */
376 if (sortval->type == REDIS_LIST || sortval->type == REDIS_SET)
377 for (j = 0; j < vectorlen; j++)
378 decrRefCount(vector[j].obj);
379 decrRefCount(sortval);
380 listRelease(operations);
381 for (j = 0; j < vectorlen; j++) {
382 if (alpha && vector[j].u.cmpobj)
383 decrRefCount(vector[j].u.cmpobj);
384 }
385 zfree(vector);
386 }
387
388