From 852795959822b60cbed190e88f7821969bc35670 Mon Sep 17 00:00:00 2001 From: Chen Tianjie Date: Wed, 27 Dec 2023 17:40:45 +0800 Subject: [PATCH] Replace slots_to_channels radix tree with slot specific dictionaries for shard channels. (#12804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We have achieved replacing `slots_to_keys` radix tree with key->slot linked list (#9356), and then replacing the list with slot specific dictionaries for keys (#11695). Shard channels behave just like keys in many ways, and we also need a slots->channels mapping. Currently this is still done by using a radix tree. So we should split `server.pubsubshard_channels` into 16384 dicts and drop the radix tree, just like what we did to DBs. Some benefits (basically the benefits of what we've done to DBs): 1. Optimize counting channels in a slot. This is currently used only in removing channels in a slot. But this is potentially more useful: sometimes we need to know how many channels there are in a specific slot when doing slot migration. Counting is now implemented by traversing the radix tree, and with this PR it will be as simple as calling `dictSize`, from O(n) to O(1). 2. The radix tree in the cluster has been removed. The shard channel names no longer require additional storage, which can save memory. 3. Potentially useful in slot migration, as shard channels are logically split by slots, thus making it easier to migrate, remove or add as a whole. 4. Avoid rehashing a big dict when there is a large number of channels. Drawbacks: 1. Takes more memory than using radix tree when there are relatively few shard channels. What this PR does: 1. in cluster mode, split `server.pubsubshard_channels` into 16384 dicts, in standalone mode, still use only one dict. 2. drop the `slots_to_channels` radix tree. 3. to save memory (to solve the drawback above), all 16384 dicts are created lazily, which means only when a channel is about to be inserted to the dict will the dict be initialized, and when all channels are deleted, the dict would delete itself. 5. use `server.shard_channel_count` to keep track of the number of all shard channels. --------- Co-authored-by: Viktor Söderqvist --- src/acl.c | 2 +- src/cluster.h | 2 - src/cluster_legacy.c | 73 +---------- src/cluster_legacy.h | 1 - src/pubsub.c | 171 +++++++++++++++---------- src/server.c | 15 ++- src/server.h | 6 +- tests/cluster/tests/26-pubsubshard.tcl | 38 +++++- 8 files changed, 159 insertions(+), 149 deletions(-) diff --git a/src/acl.c b/src/acl.c index 8ae867130..b7e43cffa 100644 --- a/src/acl.c +++ b/src/acl.c @@ -1906,7 +1906,7 @@ int ACLCheckAllPerm(client *c, int *idxptr) { int totalSubscriptions(void) { return dictSize(server.pubsub_patterns) + dictSize(server.pubsub_channels) + - dictSize(server.pubsubshard_channels); + server.shard_channel_count; } /* If 'new' can access all channels 'original' could then return NULL; diff --git a/src/cluster.h b/src/cluster.h index 97a4febd5..0bd1eb6a0 100644 --- a/src/cluster.h +++ b/src/cluster.h @@ -48,8 +48,6 @@ void clusterUpdateMyselfHostname(void); void clusterUpdateMyselfAnnouncedPorts(void); void clusterUpdateMyselfHumanNodename(void); -void slotToChannelAdd(sds channel); -void slotToChannelDel(sds channel); void clusterPropagatePublish(robj *channel, robj *message, int sharded); unsigned long getClusterConnectionsCount(void); diff --git a/src/cluster_legacy.c b/src/cluster_legacy.c index 801becf3e..f203a9416 100644 --- a/src/cluster_legacy.c +++ b/src/cluster_legacy.c @@ -1021,9 +1021,6 @@ void clusterInit(void) { exit(1); } - /* The slots -> channels map is a radix tree. Initialize it here. */ - server.cluster->slots_to_channels = raxNew(); - /* Set myself->port/cport/pport to my listening ports, we'll just need to * discover the IP address via MEET messages. */ deriveAnnouncedPorts(&myself->tcp_port, &myself->tls_port, &myself->cport); @@ -5075,7 +5072,7 @@ int verifyClusterConfigWithData(void) { /* Remove all the shard channel related information not owned by the current shard. */ static inline void removeAllNotOwnedShardChannelSubscriptions(void) { - if (!dictSize(server.pubsubshard_channels)) return; + if (!server.shard_channel_count) return; clusterNode *currmaster = clusterNodeIsMaster(myself) ? myself : myself->slaveof; for (int j = 0; j < CLUSTER_SLOTS; j++) { if (server.cluster->slots[j] != currmaster) { @@ -5664,27 +5661,9 @@ sds genClusterInfoString(void) { void removeChannelsInSlot(unsigned int slot) { - unsigned int channelcount = countChannelsInSlot(slot); - if (channelcount == 0) return; + if (countChannelsInSlot(slot) == 0) return; - /* Retrieve all the channels for the slot. */ - robj **channels = zmalloc(sizeof(robj*)*channelcount); - raxIterator iter; - int j = 0; - unsigned char indexed[2]; - - indexed[0] = (slot >> 8) & 0xff; - indexed[1] = slot & 0xff; - raxStart(&iter,server.cluster->slots_to_channels); - raxSeek(&iter,">=",indexed,2); - while(raxNext(&iter)) { - if (iter.key[0] != indexed[0] || iter.key[1] != indexed[1]) break; - channels[j++] = createStringObject((char*)iter.key + 2, iter.key_len - 2); - } - raxStop(&iter); - - pubsubUnsubscribeShardChannels(channels, channelcount); - zfree(channels); + pubsubShardUnsubscribeAllChannelsInSlot(slot); } @@ -5719,52 +5698,10 @@ unsigned int delKeysInSlot(unsigned int hashslot) { return j; } -/* ----------------------------------------------------------------------------- - * Operation(s) on channel rax tree. - * -------------------------------------------------------------------------- */ - -void slotToChannelUpdate(sds channel, int add) { - size_t keylen = sdslen(channel); - unsigned int hashslot = keyHashSlot(channel,keylen); - unsigned char buf[64]; - unsigned char *indexed = buf; - - if (keylen+2 > 64) indexed = zmalloc(keylen+2); - indexed[0] = (hashslot >> 8) & 0xff; - indexed[1] = hashslot & 0xff; - memcpy(indexed+2,channel,keylen); - if (add) { - raxInsert(server.cluster->slots_to_channels,indexed,keylen+2,NULL,NULL); - } else { - raxRemove(server.cluster->slots_to_channels,indexed,keylen+2,NULL); - } - if (indexed != buf) zfree(indexed); -} - -void slotToChannelAdd(sds channel) { - slotToChannelUpdate(channel,1); -} - -void slotToChannelDel(sds channel) { - slotToChannelUpdate(channel,0); -} - /* Get the count of the channels for a given slot. */ unsigned int countChannelsInSlot(unsigned int hashslot) { - raxIterator iter; - int j = 0; - unsigned char indexed[2]; - - indexed[0] = (hashslot >> 8) & 0xff; - indexed[1] = hashslot & 0xff; - raxStart(&iter,server.cluster->slots_to_channels); - raxSeek(&iter,">=",indexed,2); - while(raxNext(&iter)) { - if (iter.key[0] != indexed[0] || iter.key[1] != indexed[1]) break; - j++; - } - raxStop(&iter); - return j; + dict *d = server.pubsubshard_channels[hashslot]; + return d ? dictSize(d) : 0; } int clusterNodeIsMyself(clusterNode *n) { diff --git a/src/cluster_legacy.h b/src/cluster_legacy.h index 578b46fc3..a857184ab 100644 --- a/src/cluster_legacy.h +++ b/src/cluster_legacy.h @@ -318,7 +318,6 @@ struct clusterState { clusterNode *migrating_slots_to[CLUSTER_SLOTS]; clusterNode *importing_slots_from[CLUSTER_SLOTS]; clusterNode *slots[CLUSTER_SLOTS]; - rax *slots_to_channels; /* The following fields are used to take the slave state on elections. */ mstime_t failover_auth_time; /* Time of previous or next election. */ int failover_auth_count; /* Number of votes received so far. */ diff --git a/src/pubsub.c b/src/pubsub.c index 2fe7a3ff5..f8910ee4f 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -36,7 +36,7 @@ typedef struct pubsubtype { int shard; dict *(*clientPubSubChannels)(client*); int (*subscriptionCount)(client*); - dict **serverPubSubChannels; + dict **(*serverPubSubChannels)(unsigned int); robj **subscribeMsg; robj **unsubscribeMsg; robj **messageBulk; @@ -62,12 +62,22 @@ dict* getClientPubSubChannels(client *c); */ dict* getClientPubSubShardChannels(client *c); +/* + * Get server's global Pub/Sub channels dict. + */ +dict **getServerPubSubChannels(unsigned int slot); + +/* + * Get server's shard level Pub/Sub channels dict. + */ +dict **getServerPubSubShardChannels(unsigned int slot); + /* * Get list of channels client is subscribed to. * If a pattern is provided, the subset of channels is returned * matching the pattern. */ -void channelList(client *c, sds pat, dict* pubsub_channels); +void channelList(client *c, sds pat, dict** pubsub_channels, int is_sharded); /* * Pub/Sub type for global channels. @@ -76,7 +86,7 @@ pubsubtype pubSubType = { .shard = 0, .clientPubSubChannels = getClientPubSubChannels, .subscriptionCount = clientSubscriptionsCount, - .serverPubSubChannels = &server.pubsub_channels, + .serverPubSubChannels = getServerPubSubChannels, .subscribeMsg = &shared.subscribebulk, .unsubscribeMsg = &shared.unsubscribebulk, .messageBulk = &shared.messagebulk, @@ -89,7 +99,7 @@ pubsubtype pubSubShardType = { .shard = 1, .clientPubSubChannels = getClientPubSubShardChannels, .subscriptionCount = clientShardSubscriptionsCount, - .serverPubSubChannels = &server.pubsubshard_channels, + .serverPubSubChannels = getServerPubSubShardChannels, .subscribeMsg = &shared.ssubscribebulk, .unsubscribeMsg = &shared.sunsubscribebulk, .messageBulk = &shared.smessagebulk, @@ -213,7 +223,7 @@ int serverPubsubSubscriptionCount(void) { /* Return the number of pubsub shard level channels is handled. */ int serverPubsubShardSubscriptionCount(void) { - return dictSize(server.pubsubshard_channels); + return server.shard_channel_count; } @@ -235,6 +245,16 @@ dict* getClientPubSubShardChannels(client *c) { return c->pubsubshard_channels; } +dict **getServerPubSubChannels(unsigned int slot) { + UNUSED(slot); + return &server.pubsub_channels; +} + +dict **getServerPubSubShardChannels(unsigned int slot) { + serverAssert(server.cluster_enabled || slot == 0); + return &server.pubsubshard_channels[slot]; +} + /* Return the number of pubsub + pubsub shard level channels * a client is subscribed to. */ int clientTotalPubSubSubscriptionCount(client *c) { @@ -258,20 +278,32 @@ void unmarkClientAsPubSub(client *c) { /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or * 0 if the client was already subscribed to that channel. */ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { + dict **d_ptr; dictEntry *de; list *clients = NULL; int retval = 0; + unsigned int slot = 0; /* Add the channel to the client -> channels hash table */ if (dictAdd(type.clientPubSubChannels(c),channel,NULL) == DICT_OK) { retval = 1; incrRefCount(channel); /* Add the client to the channel -> list of clients hash table */ - de = dictFind(*type.serverPubSubChannels, channel); + if (server.cluster_enabled && type.shard) { + slot = c->slot; + } + d_ptr = type.serverPubSubChannels(slot); + if (*d_ptr == NULL) { + *d_ptr = dictCreate(&keylistDictType); + } + de = dictFind(*d_ptr, channel); if (de == NULL) { clients = listCreate(); - dictAdd(*type.serverPubSubChannels, channel, clients); + dictAdd(*d_ptr, channel, clients); incrRefCount(channel); + if (type.shard) { + server.shard_channel_count++; + } } else { clients = dictGetVal(de); } @@ -285,10 +317,12 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or * 0 if the client was not subscribed to the specified channel. */ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) { + dict *d; dictEntry *de; list *clients; listNode *ln; int retval = 0; + int slot = 0; /* Remove the channel from the client -> channels hash table */ incrRefCount(channel); /* channel may be just a pointer to the same object @@ -296,7 +330,12 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype ty if (dictDelete(type.clientPubSubChannels(c),channel) == DICT_OK) { retval = 1; /* Remove the client from the channel -> clients list hash table */ - de = dictFind(*type.serverPubSubChannels, channel); + if (server.cluster_enabled && type.shard) { + slot = c->slot != -1 ? c->slot : (int)keyHashSlot(channel->ptr, sdslen(channel->ptr)); + } + d = *type.serverPubSubChannels(slot); + serverAssertWithInfo(c,NULL,d != NULL); + de = dictFind(d, channel); serverAssertWithInfo(c,NULL,de != NULL); clients = dictGetVal(de); ln = listSearchKey(clients,c); @@ -306,11 +345,14 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype ty /* Free the list and associated hash entry at all if this was * the latest client, so that it will be possible to abuse * Redis PUBSUB creating millions of channels. */ - dictDelete(*type.serverPubSubChannels, channel); - /* As this channel isn't subscribed by anyone, it's safe - * to remove the channel from the slot. */ - if (server.cluster_enabled & type.shard) { - slotToChannelDel(channel->ptr); + dictDelete(d, channel); + if (type.shard) { + if (dictSize(d) == 0) { + dictRelease(d); + dict **d_ptr = type.serverPubSubChannels(slot); + *d_ptr = NULL; + } + server.shard_channel_count--; } } } @@ -322,19 +364,22 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype ty return retval; } -void pubsubShardUnsubscribeAllClients(robj *channel) { - int retval; - dictEntry *de = dictFind(server.pubsubshard_channels, channel); - serverAssertWithInfo(NULL,channel,de != NULL); - list *clients = dictGetVal(de); - if (listLength(clients) > 0) { +/* Unsubscribe all shard channels in a slot. */ +void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) { + dict *d = server.pubsubshard_channels[slot]; + if (!d) { + return; + } + dictIterator *di = dictGetSafeIterator(d); + dictEntry *de; + while ((de = dictNext(di)) != NULL) { + robj *channel = dictGetKey(de); + list *clients = dictGetVal(de); /* For each client subscribed to the channel, unsubscribe it. */ - listIter li; listNode *ln; - listRewind(clients, &li); - while ((ln = listNext(&li)) != NULL) { + while ((ln = listFirst(clients)) != NULL) { client *c = listNodeValue(ln); - retval = dictDelete(c->pubsubshard_channels, channel); + int retval = dictDelete(c->pubsubshard_channels, channel); serverAssertWithInfo(c,channel,retval == DICT_OK); addReplyPubsubUnsubscribed(c, channel, pubSubShardType); /* If the client has no other pubsub subscription, @@ -343,16 +388,14 @@ void pubsubShardUnsubscribeAllClients(robj *channel) { unmarkClientAsPubSub(c); } } + server.shard_channel_count--; + dictDelete(d, channel); } - /* Delete the channel from server pubsubshard channels hash table. */ - retval = dictDelete(server.pubsubshard_channels, channel); - /* Delete the channel from slots_to_channel mapping. */ - slotToChannelDel(channel->ptr); - serverAssertWithInfo(NULL,channel,retval == DICT_OK); - decrRefCount(channel); /* it is finally safe to release it */ + dictReleaseIterator(di); + dictRelease(d); + server.pubsubshard_channels[slot] = NULL; } - /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */ int pubsubSubscribePattern(client *c, robj *pattern) { dictEntry *de; @@ -446,17 +489,6 @@ int pubsubUnsubscribeShardAllChannels(client *c, int notify) { return count; } -/* - * Unsubscribe a client from provided shard subscribed channel(s). - */ -void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count) { - for (unsigned int j = 0; j < count; j++) { - /* Remove the channel from server and from the clients - * subscribed to it as well as notify them. */ - pubsubShardUnsubscribeAllClients(channels[j]); - } -} - /* Unsubscribe from all the patterns. Return the number of patterns the * client was subscribed from. */ int pubsubUnsubscribeAllPatterns(client *c, int notify) { @@ -483,13 +515,19 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) { */ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) { int receivers = 0; + dict *d; dictEntry *de; dictIterator *di; listNode *ln; listIter li; + unsigned int slot = 0; /* Send to clients listening for that channel */ - de = dictFind(*type.serverPubSubChannels, channel); + if (server.cluster_enabled && type.shard) { + slot = keyHashSlot(channel->ptr, sdslen(channel->ptr)); + } + d = *type.serverPubSubChannels(slot); + de = d ? dictFind(d, channel) : NULL; if (de) { list *list = dictGetVal(de); listNode *ln; @@ -658,7 +696,7 @@ NULL { /* PUBSUB CHANNELS [] */ sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr; - channelList(c, pat, server.pubsub_channels); + channelList(c, pat, &server.pubsub_channels, 0); } else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) { /* PUBSUB NUMSUB [Channel_1 ... Channel_N] */ int j; @@ -678,14 +716,15 @@ NULL { /* PUBSUB SHARDCHANNELS */ sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr; - channelList(c,pat,server.pubsubshard_channels); + channelList(c,pat,server.pubsubshard_channels,server.cluster_enabled); } else if (!strcasecmp(c->argv[1]->ptr,"shardnumsub") && c->argc >= 2) { /* PUBSUB SHARDNUMSUB [ShardChannel_1 ... ShardChannel_N] */ int j; - addReplyArrayLen(c, (c->argc-2)*2); for (j = 2; j < c->argc; j++) { - list *l = dictFetchValue(server.pubsubshard_channels, c->argv[j]); + unsigned int slot = calculateKeySlot(c->argv[j]->ptr); + dict *d = server.pubsubshard_channels[slot]; + list *l = d ? dictFetchValue(d, c->argv[j]) : NULL; addReplyBulk(c,c->argv[j]); addReplyLongLong(c,l ? listLength(l) : 0); @@ -695,25 +734,31 @@ NULL } } -void channelList(client *c, sds pat, dict *pubsub_channels) { - dictIterator *di = dictGetIterator(pubsub_channels); - dictEntry *de; +void channelList(client *c, sds pat, dict **pubsub_channels, int is_sharded) { long mblen = 0; void *replylen; + unsigned int slot_cnt = is_sharded ? CLUSTER_SLOTS : 1; replylen = addReplyDeferredLen(c); - while((de = dictNext(di)) != NULL) { - robj *cobj = dictGetKey(de); - sds channel = cobj->ptr; - - if (!pat || stringmatchlen(pat, sdslen(pat), - channel, sdslen(channel),0)) - { - addReplyBulk(c,cobj); - mblen++; + for (unsigned int i = 0; i < slot_cnt; i++) { + if (pubsub_channels[i] == NULL) { + continue; } + dictIterator *di = dictGetIterator(pubsub_channels[i]); + dictEntry *de; + while((de = dictNext(di)) != NULL) { + robj *cobj = dictGetKey(de); + sds channel = cobj->ptr; + + if (!pat || stringmatchlen(pat, sdslen(pat), + channel, sdslen(channel),0)) + { + addReplyBulk(c,cobj); + mblen++; + } + } + dictReleaseIterator(di); } - dictReleaseIterator(di); setDeferredArrayLen(c,replylen,mblen); } @@ -735,14 +780,6 @@ void ssubscribeCommand(client *c) { } for (int j = 1; j < c->argc; j++) { - /* A channel is only considered to be added, if a - * subscriber exists for it. And if a subscriber - * already exists the slotToChannel doesn't needs - * to be incremented. */ - if (server.cluster_enabled & - (dictFind(*pubSubShardType.serverPubSubChannels, c->argv[j]) == NULL)) { - slotToChannelAdd(c->argv[j]->ptr); - } pubsubSubscribeChannel(c, c->argv[j], pubSubShardType); } markClientAsPubSub(c); diff --git a/src/server.c b/src/server.c index 872c327a3..0b45616c3 100644 --- a/src/server.c +++ b/src/server.c @@ -2714,10 +2714,10 @@ void initServer(void) { server.db = zmalloc(sizeof(redisDb)*server.dbnum); /* Create the Redis databases, and initialize other internal state. */ - for (j = 0; j < server.dbnum; j++) { - int slotCount = (server.cluster_enabled) ? CLUSTER_SLOTS : 1; - server.db[j].dict = dictCreateMultiple(&dbDictType, slotCount); - server.db[j].expires = dictCreateMultiple(&dbExpiresDictType,slotCount); + int slot_count = (server.cluster_enabled) ? CLUSTER_SLOTS : 1; + for (j = 0; j < server.dbnum; j++) { + server.db[j].dict = dictCreateMultiple(&dbDictType, slot_count); + server.db[j].expires = dictCreateMultiple(&dbExpiresDictType,slot_count); server.db[j].expires_cursor = 0; server.db[j].blocking_keys = dictCreate(&keylistDictType); server.db[j].blocking_keys_unblock_on_nokey = dictCreate(&objectKeyPointerValueDictType); @@ -2726,7 +2726,7 @@ void initServer(void) { server.db[j].id = j; server.db[j].avg_ttl = 0; server.db[j].defrag_later = listCreate(); - server.db[j].dict_count = slotCount; + server.db[j].dict_count = slot_count; initDbState(&server.db[j]); listSetFreeMethod(server.db[j].defrag_later,(void (*)(void*))sdsfree); } @@ -2734,7 +2734,8 @@ void initServer(void) { evictionPoolAlloc(); /* Initialize the LRU keys pool. */ server.pubsub_channels = dictCreate(&keylistDictType); server.pubsub_patterns = dictCreate(&keylistDictType); - server.pubsubshard_channels = dictCreate(&keylistDictType); + server.pubsubshard_channels = zcalloc(sizeof(dict *) * slot_count); + server.shard_channel_count = 0; server.pubsub_clients = 0; server.cronloops = 0; server.in_exec = 0; @@ -5869,7 +5870,7 @@ sds genRedisInfoString(dict *section_dict, int all_sections, int everything) { "keyspace_misses:%lld\r\n", server.stat_keyspace_misses, "pubsub_channels:%ld\r\n", dictSize(server.pubsub_channels), "pubsub_patterns:%lu\r\n", dictSize(server.pubsub_patterns), - "pubsubshard_channels:%lu\r\n", dictSize(server.pubsubshard_channels), + "pubsubshard_channels:%llu\r\n", server.shard_channel_count, "latest_fork_usec:%lld\r\n", server.stat_fork_time, "total_forks:%lld\r\n", server.stat_total_forks, "migrate_cached_sockets:%ld\r\n", dictSize(server.migrate_cached_sockets), diff --git a/src/server.h b/src/server.h index 99bce884a..a0ffdf746 100644 --- a/src/server.h +++ b/src/server.h @@ -1994,7 +1994,8 @@ struct redisServer { dict *pubsub_patterns; /* A dict of pubsub_patterns */ int notify_keyspace_events; /* Events to propagate via Pub/Sub. This is an xor of NOTIFY_... flags. */ - dict *pubsubshard_channels; /* Map shard channels to list of subscribed clients */ + dict **pubsubshard_channels; /* Map shard channels in every slot to list of subscribed clients */ + unsigned long long shard_channel_count; unsigned int pubsub_clients; /* # of clients in Pub/Sub mode */ /* Cluster */ int cluster_enabled; /* Is cluster enabled? */ @@ -2498,6 +2499,7 @@ extern dictType sdsHashDictType; extern dictType dbExpiresDictType; extern dictType modulesDictType; extern dictType sdsReplyDictType; +extern dictType keylistDictType; extern dict *modules; /*----------------------------------------------------------------------------- @@ -3197,7 +3199,7 @@ robj *hashTypeDup(robj *o); /* Pub / Sub */ int pubsubUnsubscribeAllChannels(client *c, int notify); int pubsubUnsubscribeShardAllChannels(client *c, int notify); -void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count); +void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot); int pubsubUnsubscribeAllPatterns(client *c, int notify); int pubsubPublishMessage(robj *channel, robj *message, int sharded); int pubsubPublishMessageAndPropagateToCluster(robj *channel, robj *message, int sharded); diff --git a/tests/cluster/tests/26-pubsubshard.tcl b/tests/cluster/tests/26-pubsubshard.tcl index 2619eda0a..34939acf7 100644 --- a/tests/cluster/tests/26-pubsubshard.tcl +++ b/tests/cluster/tests/26-pubsubshard.tcl @@ -56,6 +56,21 @@ test "client can subscribe to multiple shard channels across different slots in $cluster sunsubscribe ch7 } +test "sunsubscribe without specifying any channel would unsubscribe all shard channels subscribed" { + set publishclient [redis_client_by_addr $publishnode(host) $publishnode(port)] + set subscribeclient [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)] + + set sub_res [ssubscribe $subscribeclient [list "\{channel.0\}1" "\{channel.0\}2" "\{channel.0\}3"]] + assert_equal [list 1 2 3] $sub_res + sunsubscribe $subscribeclient + + assert_equal 0 [$publishclient spublish "\{channel.0\}1" hello] + assert_equal 0 [$publishclient spublish "\{channel.0\}2" hello] + assert_equal 0 [$publishclient spublish "\{channel.0\}3" hello] + + $publishclient close + $subscribeclient close +} test "Verify Pub/Sub and Pub/Sub shard no overlap" { set slot [$cluster cluster keyslot "channel.0"] @@ -91,4 +106,25 @@ test "Verify Pub/Sub and Pub/Sub shard no overlap" { $publishclient close $subscribeclient close $subscribeshardclient close -} \ No newline at end of file +} + +test "PUBSUB channels/shardchannels" { + set subscribeclient [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)] + set subscribeclient2 [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)] + set subscribeclient3 [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)] + set publishclient [redis_client_by_addr $publishnode(host) $publishnode(port)] + + ssubscribe $subscribeclient [list "\{channel.0\}1"] + ssubscribe $subscribeclient2 [list "\{channel.0\}2"] + ssubscribe $subscribeclient3 [list "\{channel.0\}3"] + assert_equal {3} [llength [$publishclient pubsub shardchannels]] + + subscribe $subscribeclient [list "\{channel.0\}4"] + assert_equal {3} [llength [$publishclient pubsub shardchannels]] + + sunsubscribe $subscribeclient + set channel_list [$publishclient pubsub shardchannels] + assert_equal {2} [llength $channel_list] + assert {[lsearch -exact $channel_list "\{channel.0\}2"] >= 0} + assert {[lsearch -exact $channel_list "\{channel.0\}3"] >= 0} +}