diff --git a/kombu/transport/redis_cluster.py b/kombu/transport/redis_cluster.py index 92851f881..4449c73c9 100644 --- a/kombu/transport/redis_cluster.py +++ b/kombu/transport/redis_cluster.py @@ -138,6 +138,21 @@ def on_readable(self, fileno): if chan.qos.can_consume(): return chan.handlers[cmd](**{'conn': conn}) +class RedisClusterConnection(): + connections = {} + @classmethod + def get_connection(cls, host, port): + key = (host, port) + if key not in cls.connections: + cls.connections[key] = cls.create_connection(host, port) + return cls.connections[key] + + @classmethod + def create_connection(cls, host, port): + params = {'skip_full_coverage_check': True, 'host': host, 'port': port} + + return redis.RedisCluster(**params) + class Channel(RedisChannel): @@ -187,7 +202,9 @@ def conn_or_acquire(self, client=None): yield self.client def _create_client(self, asynchronous=False): - return self.connection.cluster_connection + conninfo = self.connection.client + + return RedisClusterConnection.get_connection(conninfo.hostname, conninfo.port) def _brpop_start(self, timeout=1): queues = self._queue_cycle.consume(len(self.active_queues)) @@ -256,13 +273,5 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cycle = ClusterPoller() - params = {'skip_full_coverage_check': True, 'host': self.client.hostname, 'port': self.client.port} - self.cluster_connection = redis.RedisCluster(**params) - - def close_connection(self, connection): - super().close_connection(connection) - - self.cluster_connection.close() - def driver_version(self): return redis.__version__