diff --git a/CHANGELOG.md b/CHANGELOG.md index 228aa7bdf2..80f7c054d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: AZ Affinity - Python Wrapper Support ([#2676](https://github.com/valkey-io/valkey-glide/pull/2676)) * Python: Client API for retrieving internal statistics ([#2707](https://github.com/valkey-io/valkey-glide/pull/2707)) * Node, Python: Adding support for replacing connection configured password ([#2651](https://github.com/valkey-io/valkey-glide/pull/2651))([#2659](https://github.com/valkey-io/valkey-glide/pull/2659)) * Node: Add FT._ALIASLIST command([#2652](https://github.com/valkey-io/valkey-glide/pull/2652)) diff --git a/python/python/glide/config.py b/python/python/glide/config.py index db85202876..b33c037cbf 100644 --- a/python/python/glide/config.py +++ b/python/python/glide/config.py @@ -41,6 +41,11 @@ class ReadFrom(Enum): Spread the requests between all replicas in a round robin manner. If no replica is available, route the requests to the primary. """ + AZ_AFFINITY = ProtobufReadFrom.AZAffinity + """ + Spread the read requests between replicas in the same client's AZ (Aviliablity zone) in a round robin manner, + falling back to other replicas or the primary if needed + """ class ProtocolVersion(Enum): @@ -135,6 +140,7 @@ def __init__( client_name: Optional[str] = None, protocol: ProtocolVersion = ProtocolVersion.RESP3, inflight_requests_limit: Optional[int] = None, + client_az: Optional[str] = None, ): """ Represents the configuration settings for a Glide client. @@ -172,6 +178,12 @@ def __init__( self.client_name = client_name self.protocol = protocol self.inflight_requests_limit = inflight_requests_limit + self.client_az = client_az + + if read_from == ReadFrom.AZ_AFFINITY and not client_az: + raise ValueError( + "client_az mus t be set when read_from is set to AZ_AFFINITY" + ) def _create_a_protobuf_conn_request( self, cluster_mode: bool = False @@ -204,6 +216,8 @@ def _create_a_protobuf_conn_request( request.protocol = self.protocol.value if self.inflight_requests_limit: request.inflight_requests_limit = self.inflight_requests_limit + if self.client_az: + request.client_az = self.client_az return request @@ -293,6 +307,7 @@ def __init__( protocol: ProtocolVersion = ProtocolVersion.RESP3, pubsub_subscriptions: Optional[PubSubSubscriptions] = None, inflight_requests_limit: Optional[int] = None, + client_az: Optional[str] = None, ): super().__init__( addresses=addresses, @@ -303,6 +318,7 @@ def __init__( client_name=client_name, protocol=protocol, inflight_requests_limit=inflight_requests_limit, + client_az=client_az, ) self.reconnect_strategy = reconnect_strategy self.database_id = database_id @@ -442,6 +458,7 @@ def __init__( ] = PeriodicChecksStatus.ENABLED_DEFAULT_CONFIGS, pubsub_subscriptions: Optional[PubSubSubscriptions] = None, inflight_requests_limit: Optional[int] = None, + client_az: Optional[str] = None, ): super().__init__( addresses=addresses, @@ -452,6 +469,7 @@ def __init__( client_name=client_name, protocol=protocol, inflight_requests_limit=inflight_requests_limit, + client_az=client_az, ) self.periodic_checks = periodic_checks self.pubsub_subscriptions = pubsub_subscriptions diff --git a/python/python/tests/conftest.py b/python/python/tests/conftest.py index 437fbd8fbb..85bc58c4b1 100644 --- a/python/python/tests/conftest.py +++ b/python/python/tests/conftest.py @@ -9,6 +9,7 @@ GlideClusterClientConfiguration, NodeAddress, ProtocolVersion, + ReadFrom, ServerCredentials, ) from glide.exceptions import ClosingError, RequestError @@ -132,6 +133,7 @@ def create_clusters(tls, load_module, cluster_endpoints, standalone_endpoints): cluster_mode=True, load_module=load_module, addresses=cluster_endpoints, + replica_count=1, ) pytest.standalone_cluster = ValkeyCluster( tls=tls, @@ -248,6 +250,8 @@ async def create_client( GlideClientConfiguration.PubSubSubscriptions ] = None, inflight_requests_limit: Optional[int] = None, + read_from: ReadFrom = ReadFrom.PRIMARY, + client_az: Optional[str] = None, ) -> Union[GlideClient, GlideClusterClient]: # Create async socket client use_tls = request.config.getoption("--tls") @@ -265,6 +269,8 @@ async def create_client( request_timeout=timeout, pubsub_subscriptions=cluster_mode_pubsub, inflight_requests_limit=inflight_requests_limit, + read_from=read_from, + client_az=client_az, ) return await GlideClusterClient.create(cluster_config) else: @@ -281,6 +287,8 @@ async def create_client( request_timeout=timeout, pubsub_subscriptions=standalone_mode_pubsub, inflight_requests_limit=inflight_requests_limit, + read_from=read_from, + client_az=client_az, ) return await GlideClient.create(config) @@ -381,3 +389,26 @@ async def test_meow_meow(...): reason=f"This feature added in version {min_version}", allow_module_level=True, ) + + +@pytest.fixture(scope="module") +def multiple_replicas_cluster(request): + """ + Fixture to create a special cluster with 4 replicas for specific tests. + """ + tls = request.config.getoption("--tls") + load_module = request.config.getoption("--load-module") + cluster_endpoints = request.config.getoption("--cluster-endpoints") + + if not cluster_endpoints: + multiple_replica_cluster = ValkeyCluster( + tls=tls, + cluster_mode=True, + load_module=load_module, + addresses=cluster_endpoints, + replica_count=4, + ) + yield multiple_replica_cluster + multiple_replica_cluster.__del__() + else: + yield None diff --git a/python/python/tests/test_config.py b/python/python/tests/test_config.py index 93c280245f..3b22adb09c 100644 --- a/python/python/tests/test_config.py +++ b/python/python/tests/test_config.py @@ -52,3 +52,18 @@ def test_periodic_checks_interval_to_protobuf(): config.periodic_checks = PeriodicChecksManualInterval(30) request = config._create_a_protobuf_conn_request(cluster_mode=True) assert request.periodic_checks_manual_interval.duration_in_sec == 30 + + +def test_convert_config_with_azaffinity_to_protobuf(): + az = "us-east-1a" + config = BaseClientConfiguration( + [NodeAddress("127.0.0.1")], + use_tls=True, + read_from=ReadFrom.AZ_AFFINITY, + client_az=az, + ) + request = config._create_a_protobuf_conn_request() + assert isinstance(request, ConnectionRequest) + assert request.tls_mode is TlsMode.SecureTls + assert request.read_from == ProtobufReadFrom.AZAffinity + assert request.client_az == az diff --git a/python/python/tests/test_read_from_strategy.py b/python/python/tests/test_read_from_strategy.py new file mode 100644 index 0000000000..fc15481a07 --- /dev/null +++ b/python/python/tests/test_read_from_strategy.py @@ -0,0 +1,228 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +import re + +import pytest +from glide.async_commands.core import InfoSection +from glide.config import ProtocolVersion, ReadFrom +from glide.constants import OK +from glide.glide_client import GlideClusterClient +from glide.routes import AllNodes, SlotIdRoute, SlotType +from tests.conftest import create_client +from tests.utils.utils import get_first_result + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("multiple_replicas_cluster") +class TestAZAffinity: + async def _get_num_replicas(self, client: GlideClusterClient) -> int: + info_replicas = get_first_result( + await client.info([InfoSection.REPLICATION]) + ).decode() + match = re.search(r"connected_slaves:(\d+)", info_replicas) + if match: + return int(match.group(1)) + else: + raise ValueError( + "Could not find the number of replicas in the INFO REPLICATION response" + ) + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_routing_by_slot_to_replica_with_az_affinity_strategy_to_all_replicas( + self, + request, + cluster_mode: bool, + protocol: ProtocolVersion, + multiple_replicas_cluster, + ): + """Test that the client with AZ affinity strategy routes in a round-robin manner to all replicas within the specified AZ""" + + az = "us-east-1a" + client_for_config_set = await create_client( + request, + cluster_mode, + addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + timeout=2000, + ) + await client_for_config_set.config_resetstat() == OK + await client_for_config_set.custom_command( + ["CONFIG", "SET", "availability-zone", az], AllNodes() + ) + await client_for_config_set.close() + + client_for_testing_az = await create_client( + request, + cluster_mode, + addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + client_az=az, + ) + azs = await client_for_testing_az.custom_command( + ["CONFIG", "GET", "availability-zone"], AllNodes() + ) + + # Check that all replicas have the availability zone set to the az + assert all( + ( + node[1].decode() == az + if isinstance(node, list) + else node[b"availability-zone"].decode() == az + ) + for node in azs.values() + ) + + n_replicas = await self._get_num_replicas(client_for_testing_az) + GET_CALLS = 3 * n_replicas + get_cmdstat = f"cmdstat_get:calls={GET_CALLS // n_replicas}" + + for _ in range(GET_CALLS): + await client_for_testing_az.get("foo") + + info_result = await client_for_testing_az.info( + [InfoSection.COMMAND_STATS, InfoSection.SERVER], AllNodes() + ) + + # Check that all replicas have the same number of GET calls + matching_entries_count = sum( + 1 + for value in info_result.values() + if get_cmdstat in value.decode() and az in value.decode() + ) + assert matching_entries_count == n_replicas + + await client_for_testing_az.close() + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_routing_with_az_affinity_strategy_to_1_replica( + self, + request, + cluster_mode: bool, + protocol: ProtocolVersion, + multiple_replicas_cluster, + ): + """Test that the client with az affinity strategy will only route to the 1 replica with the same az""" + az = "us-east-1a" + GET_CALLS = 3 + get_cmdstat = f"cmdstat_get:calls={GET_CALLS}" + + client_for_config_set = await create_client( + request, + cluster_mode, + addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + timeout=2000, + ) + + # Reset the availability zone for all nodes + await client_for_config_set.custom_command( + ["CONFIG", "SET", "availability-zone", ""], + route=AllNodes(), + ) + await client_for_config_set.config_resetstat() == OK + + # 12182 is the slot of "foo" + await client_for_config_set.custom_command( + ["CONFIG", "SET", "availability-zone", az], + route=SlotIdRoute(SlotType.REPLICA, 12182), + ) + + await client_for_config_set.close() + + client_for_testing_az = await create_client( + request, + cluster_mode, + addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + client_az=az, + ) + + for _ in range(GET_CALLS): + await client_for_testing_az.get("foo") + + info_result = await client_for_testing_az.info( + [InfoSection.SERVER, InfoSection.COMMAND_STATS], AllNodes() + ) + + # Check that only the replica with az has all the GET calls + matching_entries_count = sum( + 1 + for value in info_result.values() + if get_cmdstat in value.decode() and az in value.decode() + ) + assert matching_entries_count == 1 + + # Check that the other replicas have no availability zone set + changed_az_count = sum( + 1 + for node in info_result.values() + if f"availability_zone:{az}" in node.decode() + ) + assert changed_az_count == 1 + + await client_for_testing_az.close() + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_az_affinity_non_existing_az( + self, + request, + cluster_mode: bool, + protocol: ProtocolVersion, + multiple_replicas_cluster, + ): + GET_CALLS = 4 + + client_for_testing_az = await create_client( + request, + cluster_mode, + addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + client_az="non-existing-az", + ) + await client_for_testing_az.config_resetstat() == OK + + for _ in range(GET_CALLS): + await client_for_testing_az.get("foo") + + n_replicas = await self._get_num_replicas(client_for_testing_az) + # We expect the calls to be distributed evenly among the replicas + get_cmdstat = f"cmdstat_get:calls={GET_CALLS // n_replicas}" + + info_result = await client_for_testing_az.info( + [InfoSection.COMMAND_STATS, InfoSection.SERVER], AllNodes() + ) + + matching_entries_count = sum( + 1 for value in info_result.values() if get_cmdstat in value.decode() + ) + assert matching_entries_count == GET_CALLS + + await client_for_testing_az.close() + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_az_affinity_requires_client_az( + self, request, cluster_mode: bool, protocol: ProtocolVersion + ): + """Test that setting read_from to AZ_AFFINITY without client_az raises an error.""" + with pytest.raises(ValueError): + await create_client( + request, + cluster_mode=cluster_mode, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + )