From ffc6b7f864dcaa58b6c5d81d7e595050fe954dec Mon Sep 17 00:00:00 2001 From: antirez Date: Thu, 1 Apr 2010 13:13:29 +0200 Subject: [PATCH] Pub/Sub pattern matching capabilities --- redis.c | 262 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 210 insertions(+), 52 deletions(-) diff --git a/redis.c b/redis.c index 097283da..bb0e2004 100644 --- a/redis.c +++ b/redis.c @@ -327,7 +327,8 @@ typedef struct redisClient { * is >= blockingto then the operation timed out. */ list *io_keys; /* Keys this client is waiting to be loaded from the * swap file in order to continue. */ - dict *pubsub_classes; /* Classes a client is interested in (SUBSCRIBE) */ + dict *pubsub_channels; /* channels a client is interested in (SUBSCRIBE) */ + list *pubsub_patterns; /* patterns a client is interested in (SUBSCRIBE) */ } redisClient; struct saveparam { @@ -437,11 +438,17 @@ struct redisServer { unsigned long long vm_stats_swapouts; unsigned long long vm_stats_swapins; /* Pubsub */ - dict *pubsub_classes; /* Associate classes to list of subscribed clients */ + dict *pubsub_channels; /* Map channels to list of subscribed clients */ + list *pubsub_patterns; /* A list of pubsub_patterns */ /* Misc */ FILE *devnull; }; +typedef struct pubsubPattern { + redisClient *client; + robj *pattern; +} pubsubPattern; + typedef void redisCommandProc(redisClient *c); struct redisCommand { char *name; @@ -506,7 +513,8 @@ struct sharedObjectsStruct { *outofrangeerr, *plus, *select0, *select1, *select2, *select3, *select4, *select5, *select6, *select7, *select8, *select9, - *messagebulk, *subscribebulk, *unsubscribebulk, *mbulk3; + *messagebulk, *subscribebulk, *unsubscribebulk, *mbulk3, + *psubscribebulk, *punsubscribebulk; } shared; /* Global vars that are actally used as constants. The following double @@ -606,7 +614,11 @@ static struct redisCommand *lookupCommand(char *name); static void call(redisClient *c, struct redisCommand *cmd); static void resetClient(redisClient *c); static void convertToRealHash(robj *o); -static int pubsubUnsubscribeAll(redisClient *c, int notify); +static int pubsubUnsubscribeAllChannels(redisClient *c, int notify); +static int pubsubUnsubscribeAllPatterns(redisClient *c, int notify); +static void freePubsubPattern(void *p); +static int listMatchPubsubPattern(void *a, void *b); +static int compareStringObjects(robj *a, robj *b); static void usage(); static void authCommand(redisClient *c); @@ -707,6 +719,8 @@ static void configCommand(redisClient *c); static void hincrbyCommand(redisClient *c); static void subscribeCommand(redisClient *c); static void unsubscribeCommand(redisClient *c); +static void psubscribeCommand(redisClient *c); +static void punsubscribeCommand(redisClient *c); static void publishCommand(redisClient *c); /*================================= Globals ================================= */ @@ -813,6 +827,8 @@ static struct redisCommand cmdTable[] = { {"config",configCommand,-2,REDIS_CMD_BULK,NULL,0,0,0}, {"subscribe",subscribeCommand,-2,REDIS_CMD_INLINE,NULL,0,0,0}, {"unsubscribe",unsubscribeCommand,-1,REDIS_CMD_INLINE,NULL,0,0,0}, + {"psubscribe",psubscribeCommand,-2,REDIS_CMD_INLINE,NULL,0,0,0}, + {"punsubscribe",punsubscribeCommand,-1,REDIS_CMD_INLINE,NULL,0,0,0}, {"publish",publishCommand,3,REDIS_CMD_BULK,NULL,0,0,0}, {NULL,NULL,0,0,NULL,0,0,0} }; @@ -1152,7 +1168,8 @@ static void closeTimedoutClients(void) { if (server.maxidletime && !(c->flags & REDIS_SLAVE) && /* no timeout for slaves */ !(c->flags & REDIS_MASTER) && /* no timeout for masters */ - dictSize(c->pubsub_classes) == 0 && /* no timeout for pubsub */ + dictSize(c->pubsub_channels) == 0 && /* no timeout for pubsub */ + listLength(c->pubsub_patterns) == 0 && (now - c->lastinteraction > server.maxidletime)) { redisLog(REDIS_VERBOSE,"Closing idle client"); @@ -1488,6 +1505,8 @@ static void createSharedObjects(void) { shared.messagebulk = createStringObject("$7\r\nmessage\r\n",13); shared.subscribebulk = createStringObject("$9\r\nsubscribe\r\n",15); shared.unsubscribebulk = createStringObject("$11\r\nunsubscribe\r\n",18); + shared.psubscribebulk = createStringObject("$10\r\npsubscribe\r\n",17); + shared.punsubscribebulk = createStringObject("$12\r\npunsubscribe\r\n",19); shared.mbulk3 = createStringObject("*3\r\n",4); } @@ -1592,7 +1611,10 @@ static void initServer() { server.db[j].io_keys = dictCreate(&keylistDictType,NULL); server.db[j].id = j; } - server.pubsub_classes = dictCreate(&keylistDictType,NULL); + server.pubsub_channels = dictCreate(&keylistDictType,NULL); + server.pubsub_patterns = listCreate(); + listSetFreeMethod(server.pubsub_patterns,freePubsubPattern); + listSetMatchMethod(server.pubsub_patterns,listMatchPubsubPattern); server.cronloops = 0; server.bgsavechildpid = -1; server.bgrewritechildpid = -1; @@ -1856,9 +1878,11 @@ static void freeClient(redisClient *c) { if (c->flags & REDIS_BLOCKED) unblockClientWaitingData(c); - /* Unsubscribe from all the pubsub classes */ - pubsubUnsubscribeAll(c,0); - dictRelease(c->pubsub_classes); + /* Unsubscribe from all the pubsub channels */ + pubsubUnsubscribeAllChannels(c,0); + pubsubUnsubscribeAllPatterns(c,0); + dictRelease(c->pubsub_channels); + listRelease(c->pubsub_patterns); /* Obvious cleanup */ aeDeleteFileEvent(server.el,c->fd,AE_READABLE); aeDeleteFileEvent(server.el,c->fd,AE_WRITABLE); @@ -2266,9 +2290,10 @@ static int processCommand(redisClient *c) { } /* Only allow SUBSCRIBE and UNSUBSCRIBE in the context of Pub/Sub */ - if (dictSize(c->pubsub_classes) > 0 && - cmd->proc != subscribeCommand && cmd->proc != unsubscribeCommand) { - addReplySds(c,sdsnew("-ERR only SUBSCRIBE / UNSUBSCRIBE / QUIT allowed in this context\r\n")); + if (dictSize(c->pubsub_channels) > 0 && + cmd->proc != subscribeCommand && cmd->proc != unsubscribeCommand && + cmd->proc != psubscribeCommand && cmd->proc != punsubscribeCommand) { + addReplySds(c,sdsnew("-ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / QUIT allowed in this context\r\n")); resetClient(c); return 1; } @@ -2484,6 +2509,10 @@ static void *dupClientReplyValue(void *o) { return o; } +static int listMatchObjects(void *a, void *b) { + return compareStringObjects(a,b) == 0; +} + static redisClient *createClient(int fd) { redisClient *c = zmalloc(sizeof(*c)); @@ -2510,8 +2539,11 @@ static redisClient *createClient(int fd) { c->blockingkeys = NULL; c->blockingkeysnum = 0; c->io_keys = listCreate(); - c->pubsub_classes = dictCreate(&setDictType,NULL); listSetFreeMethod(c->io_keys,decrRefCount); + c->pubsub_channels = dictCreate(&setDictType,NULL); + c->pubsub_patterns = listCreate(); + listSetFreeMethod(c->pubsub_patterns,decrRefCount); + listSetMatchMethod(c->pubsub_patterns,listMatchObjects); if (aeCreateFileEvent(server.el, c->fd, AE_READABLE, readQueryFromClient, c) == AE_ERR) { freeClient(c); @@ -6675,7 +6707,8 @@ static sds genRedisInfoString(void) { "expired_keys:%lld\r\n" "hash_max_zipmap_entries:%ld\r\n" "hash_max_zipmap_value:%ld\r\n" - "pubsub_classes:%ld\r\n" + "pubsub_channels:%ld\r\n" + "pubsub_patterns:%u\r\n" "vm_enabled:%d\r\n" "role:%s\r\n" ,REDIS_VERSION, @@ -6698,7 +6731,8 @@ static sds genRedisInfoString(void) { server.stat_expiredkeys, server.hash_max_zipmap_entries, server.hash_max_zipmap_value, - dictSize(server.pubsub_classes), + dictSize(server.pubsub_channels), + listLength(server.pubsub_patterns), server.vm_enabled != 0, server.masterhost == NULL ? "master" : "slave" ); @@ -9271,23 +9305,37 @@ badarity: /* =========================== Pubsub implementation ======================== */ -/* Subscribe a client to a class. Returns 1 if the operation succeeded, or - * 0 if the client was already subscribed to that class. */ -static int pubsubSubscribe(redisClient *c, robj *class) { +static void freePubsubPattern(void *p) { + pubsubPattern *pat = p; + + decrRefCount(pat->pattern); + zfree(pat); +} + +static int listMatchPubsubPattern(void *a, void *b) { + pubsubPattern *pa = a, *pb = b; + + return (pa->client == pb->client) && + (compareStringObjects(pa->pattern,pb->pattern) == 0); +} + +/* Subscribe a client to a channel. Returns 1 if the operation succeeded, or + * 0 if the client was already subscribed to that channel. */ +static int pubsubSubscribeChannel(redisClient *c, robj *channel) { struct dictEntry *de; list *clients = NULL; int retval = 0; - /* Add the class to the client -> classes hash table */ - if (dictAdd(c->pubsub_classes,class,NULL) == DICT_OK) { + /* Add the channel to the client -> channels hash table */ + if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) { retval = 1; - incrRefCount(class); - /* Add the client to the class -> list of clients hash table */ - de = dictFind(server.pubsub_classes,class); + incrRefCount(channel); + /* Add the client to the channel -> list of clients hash table */ + de = dictFind(server.pubsub_channels,channel); if (de == NULL) { clients = listCreate(); - dictAdd(server.pubsub_classes,class,clients); - incrRefCount(class); + dictAdd(server.pubsub_channels,channel,clients); + incrRefCount(channel); } else { clients = dictGetEntryVal(de); } @@ -9296,26 +9344,26 @@ static int pubsubSubscribe(redisClient *c, robj *class) { /* Notify the client */ addReply(c,shared.mbulk3); addReply(c,shared.subscribebulk); - addReplyBulk(c,class); - addReplyLong(c,dictSize(c->pubsub_classes)); + addReplyBulk(c,channel); + addReplyLong(c,dictSize(c->pubsub_channels)+listLength(c->pubsub_patterns)); return retval; } -/* Unsubscribe a client from a class. Returns 1 if the operation succeeded, or - * 0 if the client was not subscribed to the specified class. */ -static int pubsubUnsubscribe(redisClient *c, robj *class, int notify) { +/* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or + * 0 if the client was not subscribed to the specified channel. */ +static int pubsubUnsubscribeChannel(redisClient *c, robj *channel, int notify) { struct dictEntry *de; list *clients; listNode *ln; int retval = 0; - /* Remove the class from the client -> classes hash table */ - incrRefCount(class); /* class may be just a pointer to the same object + /* Remove the channel from the client -> channels hash table */ + incrRefCount(channel); /* channel may be just a pointer to the same object we have in the hash tables. Protect it... */ - if (dictDelete(c->pubsub_classes,class) == DICT_OK) { + if (dictDelete(c->pubsub_channels,channel) == DICT_OK) { retval = 1; - /* Remove the client from the class -> clients list hash table */ - de = dictFind(server.pubsub_classes,class); + /* Remove the client from the channel -> clients list hash table */ + de = dictFind(server.pubsub_channels,channel); assert(de != NULL); clients = dictGetEntryVal(de); ln = listSearchKey(clients,c); @@ -9324,43 +9372,114 @@ static int pubsubUnsubscribe(redisClient *c, robj *class, int notify) { if (listLength(clients) == 0) { /* Free the list and associated hash entry at all if this was * the latest client, so that it will be possible to abuse - * Redis PUBSUB creating millions of classes. */ - dictDelete(server.pubsub_classes,class); + * Redis PUBSUB creating millions of channels. */ + dictDelete(server.pubsub_channels,channel); } } /* Notify the client */ if (notify) { addReply(c,shared.mbulk3); addReply(c,shared.unsubscribebulk); - addReplyBulk(c,class); - addReplyLong(c,dictSize(c->pubsub_classes)); + addReplyBulk(c,channel); + addReplyLong(c,dictSize(c->pubsub_channels)+ + listLength(c->pubsub_patterns)); + + } + decrRefCount(channel); /* it is finally safe to release it */ + return retval; +} + +/* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the clinet was already subscribed to that pattern. */ +static int pubsubSubscribePattern(redisClient *c, robj *pattern) { + int retval = 0; + + if (listSearchKey(c->pubsub_patterns,pattern) == NULL) { + retval = 1; + pubsubPattern *pat; + listAddNodeTail(c->pubsub_patterns,pattern); + incrRefCount(pattern); + pat = zmalloc(sizeof(*pat)); + pat->pattern = getDecodedObject(pattern); + pat->client = c; + listAddNodeTail(server.pubsub_patterns,pat); + } + /* Notify the client */ + addReply(c,shared.mbulk3); + addReply(c,shared.psubscribebulk); + addReplyBulk(c,pattern); + addReplyLong(c,dictSize(c->pubsub_channels)+listLength(c->pubsub_patterns)); + return retval; +} + +/* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or + * 0 if the client was not subscribed to the specified channel. */ +static int pubsubUnsubscribePattern(redisClient *c, robj *pattern, int notify) { + listNode *ln; + pubsubPattern pat; + int retval = 0; + + incrRefCount(pattern); /* Protect the object. May be the same we remove */ + if ((ln = listSearchKey(c->pubsub_patterns,pattern)) != NULL) { + retval = 1; + listDelNode(c->pubsub_patterns,ln); + pat.client = c; + pat.pattern = pattern; + ln = listSearchKey(server.pubsub_patterns,&pat); + listDelNode(server.pubsub_patterns,ln); + } + /* Notify the client */ + if (notify) { + addReply(c,shared.mbulk3); + addReply(c,shared.punsubscribebulk); + addReplyBulk(c,pattern); + addReplyLong(c,dictSize(c->pubsub_channels)+ + listLength(c->pubsub_patterns)); } - decrRefCount(class); /* it is finally safe to release it */ + decrRefCount(pattern); return retval; } -/* Unsubscribe from all the classes. Return the number of classes the - * client was subscribed to. */ -static int pubsubUnsubscribeAll(redisClient *c, int notify) { - dictIterator *di = dictGetIterator(c->pubsub_classes); +/* Unsubscribe from all the channels. Return the number of channels the + * client was subscribed from. */ +static int pubsubUnsubscribeAllChannels(redisClient *c, int notify) { + dictIterator *di = dictGetIterator(c->pubsub_channels); dictEntry *de; int count = 0; while((de = dictNext(di)) != NULL) { - robj *class = dictGetEntryKey(de); + robj *channel = dictGetEntryKey(de); - count += pubsubUnsubscribe(c,class,notify); + count += pubsubUnsubscribeChannel(c,channel,notify); } dictReleaseIterator(di); return count; } +/* Unsubscribe from all the patterns. Return the number of patterns the + * client was subscribed from. */ +static int pubsubUnsubscribeAllPatterns(redisClient *c, int notify) { + listNode *ln; + listIter li; + int count = 0; + + listRewind(c->pubsub_patterns,&li); + while ((ln = listNext(&li)) != NULL) { + robj *pattern = ln->value; + + count += pubsubUnsubscribePattern(c,pattern,notify); + } + return count; +} + /* Publish a message */ -static int pubsubPublishMessage(robj *class, robj *message) { +static int pubsubPublishMessage(robj *channel, robj *message) { int receivers = 0; struct dictEntry *de; + listNode *ln; + listIter li; - de = dictFind(server.pubsub_classes,class); + /* Send to clients listening for that channel */ + de = dictFind(server.pubsub_channels,channel); if (de) { list *list = dictGetEntryVal(de); listNode *ln; @@ -9372,11 +9491,31 @@ static int pubsubPublishMessage(robj *class, robj *message) { addReply(c,shared.mbulk3); addReply(c,shared.messagebulk); - addReplyBulk(c,class); + addReplyBulk(c,channel); addReplyBulk(c,message); receivers++; } } + /* Send to clients listening to matching channels */ + if (listLength(server.pubsub_patterns)) { + listRewind(server.pubsub_patterns,&li); + channel = getDecodedObject(channel); + while ((ln = listNext(&li)) != NULL) { + pubsubPattern *pat = ln->value; + + if (stringmatchlen((char*)pat->pattern->ptr, + sdslen(pat->pattern->ptr), + (char*)channel->ptr, + sdslen(channel->ptr),0)) { + addReply(pat->client,shared.mbulk3); + addReply(pat->client,shared.messagebulk); + addReplyBulk(pat->client,channel); + addReplyBulk(pat->client,message); + receivers++; + } + } + decrRefCount(channel); + } return receivers; } @@ -9384,18 +9523,37 @@ static void subscribeCommand(redisClient *c) { int j; for (j = 1; j < c->argc; j++) - pubsubSubscribe(c,c->argv[j]); + pubsubSubscribeChannel(c,c->argv[j]); } static void unsubscribeCommand(redisClient *c) { if (c->argc == 1) { - pubsubUnsubscribeAll(c,1); + pubsubUnsubscribeAllChannels(c,1); + return; + } else { + int j; + + for (j = 1; j < c->argc; j++) + pubsubUnsubscribeChannel(c,c->argv[j],1); + } +} + +static void psubscribeCommand(redisClient *c) { + int j; + + for (j = 1; j < c->argc; j++) + pubsubSubscribePattern(c,c->argv[j]); +} + +static void punsubscribeCommand(redisClient *c) { + if (c->argc == 1) { + pubsubUnsubscribeAllPatterns(c,1); return; } else { int j; for (j = 1; j < c->argc; j++) - pubsubUnsubscribe(c,c->argv[j],1); + pubsubUnsubscribePattern(c,c->argv[j],1); } } -- 2.45.2