diff --git a/django_valkey/async_cache/client/default.py b/django_valkey/async_cache/client/default.py index c724d22..5d6b657 100644 --- a/django_valkey/async_cache/client/default.py +++ b/django_valkey/async_cache/client/default.py @@ -403,12 +403,15 @@ async def clear(self, client: AValkey | Any | None = None) -> bool: aclear = clear - async def decode(self, value) -> Any: + async def decode(self, value: bytes) -> Any: """ Decode the given value. """ try: - value = int(value) + if value.isdigit(): + value = int(value) + else: + value = float(value) except (ValueError, TypeError): # Handle values that weren't compressed (small stuff) with suppress(CompressorError): @@ -419,11 +422,11 @@ async def decode(self, value) -> Any: adecode = decode - async def encode(self, value) -> bytes | int: + async def encode(self, value) -> bytes | int | float: """ Encode the given value. """ - if isinstance(value, bool) or not isinstance(value, int): + if type(value) is not int and type(value) is not float: value = self._serializer.dumps(value) return self._compressor.compress(value) diff --git a/django_valkey/base_client.py b/django_valkey/base_client.py index a1e11dd..b26c2c0 100644 --- a/django_valkey/base_client.py +++ b/django_valkey/base_client.py @@ -524,12 +524,15 @@ def clear(self, client: Backend | Any | None = None) -> bool: except _main_exceptions as e: raise ConnectionInterrupted(connection=client) from e - def decode(self, value: EncodableT) -> Any: + def decode(self, value: bytes) -> Any: """ Decode the given value. """ try: - value = int(value) + if value.isdigit(): + value = int(value) + else: + value = float(value) except (ValueError, TypeError): # Handle little values, chosen to be not compressed with suppress(CompressorError): @@ -537,12 +540,12 @@ def decode(self, value: EncodableT) -> Any: value = self._serializer.loads(value) return value - def encode(self, value: EncodableT) -> bytes | int: + def encode(self, value: EncodableT) -> bytes | int | float: """ Encode the given value. """ - if isinstance(value, bool) or not isinstance(value, int): + if type(value) is not int and type(value) is not float: value = self._serializer.dumps(value) return self._compressor.compress(value) diff --git a/tests/test_backend.py b/tests/test_backend.py index 9264ba3..071d75a 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -30,6 +30,31 @@ def patch_itersize_setting() -> Iterable[None]: class TestDjangoValkeyCache: + def test_set_int(self, cache: ValkeyCache): + if isinstance(cache.client, herd.HerdClient): + pytest.skip("herd client's set method works differently") + cache.set("test_key", 1) + result = cache.get("test_key") + assert type(result) is int + # shard client doesn't have get_client() + if not isinstance(cache.client, ShardClient): + raw_client = cache.client._get_client(write=False, client=None) + else: + raw_client = cache.client.get_server(":1:test_key") + assert raw_client.get(":1:test_key") == b"1" + + def test_set_float(self, cache: ValkeyCache): + if isinstance(cache.client, herd.HerdClient): + pytest.skip("herd client's set method works differently") + cache.set("test_key2", 1.1) + result = cache.get("test_key2") + assert type(result) is float + if not isinstance(cache.client, ShardClient): + raw_client = cache.client._get_client(write=False, client=None) + else: + raw_client = cache.client.get_server(":1:test_key2") + assert raw_client.get(":1:test_key2") == b"1.1" + def test_setnx(self, cache: ValkeyCache): # we should ensure there is no test_key_nx in valkey cache.delete("test_key_nx") @@ -867,6 +892,24 @@ def test_sadd(self, cache: ValkeyCache): assert cache.sadd("foo", "bar") == 1 assert cache.smembers("foo") == {"bar"} + def test_sadd_int(self, cache: ValkeyCache): + cache.sadd("foo", 1) + assert cache.smembers("foo") == {1} + if not isinstance(cache.client, ShardClient): + raw_client = cache.client._get_client(write=False, client=None) + else: + raw_client = cache.client.get_server(":1:foo") + assert raw_client.smembers(":1:foo") == [b"1"] + + def test_sadd_float(self, cache: ValkeyCache): + cache.sadd("foo", 1.2) + assert cache.smembers("foo") == {1.2} + if not isinstance(cache.client, ShardClient): + raw_client = cache.client._get_client(write=False, client=None) + else: + raw_client = cache.client.get_server(":1:foo") + assert raw_client.smembers(":1:foo") == [b"1.2"] + def test_scard(self, cache: ValkeyCache): cache.sadd("foo", "bar", "bar2") assert cache.scard("foo") == 2 diff --git a/tests/tests_async/test_backend.py b/tests/tests_async/test_backend.py index 68e5787..d0239f9 100644 --- a/tests/tests_async/test_backend.py +++ b/tests/tests_async/test_backend.py @@ -31,6 +31,24 @@ async def patch_itersize_setting() -> Iterable[None]: @pytest.mark.asyncio(loop_scope="session") class TestAsyncDjangoValkeyCache: + async def test_set_int(self, cache: AsyncValkeyCache): + if isinstance(cache.client, AsyncHerdClient): + pytest.skip("Herd client's set method works differently") + await cache.aset("test_key", 1) + result = await cache.aget("test_key") + assert type(result) is int + raw_client = await cache.client._get_client(write=False, client=None) + assert await raw_client.get(":1:test_key") == b"1" + + async def test_set_float(self, cache: AsyncValkeyCache): + if isinstance(cache.client, AsyncHerdClient): + pytest.skip("Herd client's set method works differently") + await cache.aset("test_key2", 1.1) + result = await cache.aget("test_key2") + assert type(result) is float + raw_client = await cache.client._get_client(write=False, client=None) + assert await raw_client.get(":1:test_key2") == b"1.1" + async def test_setnx(self, cache: AsyncValkeyCache): await cache.delete("test_key_nx") res = await cache.get("test_key_nx") @@ -895,6 +913,18 @@ async def test_sadd(self, cache: AsyncValkeyCache): assert await cache.asadd("foo", "bar") == 1 assert await cache.asmembers("foo") == {"bar"} + async def test_sadd_int(self, cache: AsyncValkeyCache): + await cache.asadd("foo", 1) + assert await cache.asmembers("foo") == {1} + raw_client = await cache.client._get_client(write=False, client=None) + assert await raw_client.smembers(":1:foo") == [b"1"] + + async def test_sadd_float(self, cache: AsyncValkeyCache): + await cache.asadd("foo", 1.2) + assert await cache.asmembers("foo") == {1.2} + raw_client = await cache.client._get_client(write=False, client=None) + assert await raw_client.smembers(":1:foo") == [b"1.2"] + async def test_scard(self, cache: AsyncValkeyCache): await cache.asadd("foo", "bar", "bar2") assert await cache.ascard("foo") == 2