Skip to content

Commit

Permalink
clean up and prefer memory cache
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Oct 11, 2024
1 parent 48cd6d4 commit f11eee3
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 34 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ The default configuration file is `aggrec.toml`. Example configuration below:
metrics_endpoint = "http://localhost:4317"
insecure = true

[cache]
size = 1000
ttl = 300


## API

Expand Down
5 changes: 3 additions & 2 deletions aggrec.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ spans_endpoint = "http://localhost:4317"
metrics_endpoint = "http://localhost:4317"
insecure = true

[redis]
host = "localhost"
[key_cache]
size = 1000
ttl = 300
25 changes: 11 additions & 14 deletions aggrec/key_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,54 +8,51 @@
logger = logging.getLogger(__name__)


DEFAULT_MEMORY_CACHE_SIZE = 1000
DEFAULT_MEMORY_CACHE_TTL = 60


class KeyCache:
@abstractmethod
def get(self, key: str) -> bytes | None:
return None

@abstractmethod
def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
def set(self, key: str, value: bytes) -> None:
pass


class DummyKeyCache(KeyCache):
def get(self, key: str) -> bytes | None:
return None

def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
def set(self, key: str, value: bytes) -> None:
pass


class MemoryKeyCache(KeyCache):
def __init__(self, size: int = DEFAULT_MEMORY_CACHE_SIZE, ttl: int = DEFAULT_MEMORY_CACHE_TTL):
def __init__(self, size: int, ttl: int):
self.cache = ExpiringDict(max_len=size, max_age_seconds=ttl)
logger.info("Using memory cache size=%d ttl=%d", size, ttl)

def get(self, key: str) -> bytes | None:
res = self.cache.get(key)
logger.debug("Cache GET %s (%s)", key, "hit" if res else "miss")
return res

def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
def set(self, key: str, value: bytes) -> None:
logger.debug("Cache SET %s", key)
self.cache[key] = value


class RedisKeyCache(KeyCache):
def __init__(self, redis_client: redis.Redis, default_ttl: int):
def __init__(self, redis_client: redis.Redis, ttl: int):
self.redis_client = redis_client
self.default_ttl = default_ttl
self.ttl = ttl
logger.info("Using Redis cache ttl=%d", ttl)

def get(self, key: str) -> bytes | None:
res = self.redis_client.get(name=key)
logger.debug("Cache GET %s (%s)", key, "hit" if res else "miss")
return res

def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
ttl = ttl if ttl is not None else self.default_ttl
expires_at = int(time.time()) + ttl
logger.debug("Cache SET %s with TTL %d EXAT %d", key, ttl, expires_at)
def set(self, key: str, value: bytes) -> None:
logger.debug("Cache SET %s", key)
expires_at = int(time.time()) + self.ttl
self.redis_client.set(name=key, value=value, exat=expires_at)
15 changes: 9 additions & 6 deletions aggrec/key_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@


class CacheKeyResolver(HTTPSignatureKeyResolver):
def __init__(self, key_cache: KeyCache):
def __init__(self, key_cache: KeyCache | None):
self.key_cache = key_cache

@abstractmethod
def get_public_key_pem(self, key_id: str) -> bytes:
pass

def resolve_public_key(self, key_id: str):
public_key_pem = self.key_cache.get(key_id)
if not public_key_pem:
if self.key_cache:
public_key_pem = self.key_cache.get(key_id)
if not public_key_pem:
public_key_pem = self.get_public_key_pem(key_id)
self.key_cache.set(key_id, public_key_pem)
else:
public_key_pem = self.get_public_key_pem(key_id)
self.key_cache.set(key_id, public_key_pem)
return load_pem_public_key(public_key_pem)


class FileKeyResolver(CacheKeyResolver):
def __init__(self, client_database_directory: str, key_cache: KeyCache):
def __init__(self, client_database_directory: str, key_cache: KeyCache | None = None):
super().__init__(key_cache=key_cache)
self.client_database_directory = client_database_directory

Expand All @@ -40,7 +43,7 @@ def get_public_key_pem(self, key_id: str) -> bytes:


class UrlKeyResolver(CacheKeyResolver):
def __init__(self, client_database_base_url: str, key_cache: KeyCache):
def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = None):
super().__init__(key_cache=key_cache)
self.client_database_base_url = client_database_base_url
self.httpx_client = httpx.Client()
Expand Down
14 changes: 8 additions & 6 deletions aggrec/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ def __init__(self, settings: Settings):
metrics_endpoint=str(settings.otlp.metrics_endpoint),
insecure=settings.otlp.insecure,
)
if self.settings.redis:
redis_client = redis.StrictRedis(host=self.settings.redis.host, port=self.settings.redis.port)
self.logger.debug("Using REDIS at %s:%d", self.settings.redis.host, self.settings.redis.port)
self.key_cache = RedisKeyCache(redis_client=redis_client, default_ttl=self.settings.redis.ttl)
else:
self.key_cache = MemoryKeyCache()
self.key_cache = None
if self.settings.key_cache:
if redis_settings := self.settings.key_cache.redis:
redis_client = redis.StrictRedis(host=redis_settings.host, port=redis_settings.port)
self.logger.debug("Using REDIS at %s:%d", redis_settings.host, redis_settings.port)
self.key_cache = RedisKeyCache(redis_client=redis_client, ttl=self.settings.key_cache.ttl)
elif self.settings.key_cache.size:
self.key_cache = MemoryKeyCache(size=self.settings.key_cache.size, ttl=self.settings.key_cache.ttl)

@staticmethod
def connect_mongodb(settings: Settings):
Expand Down
9 changes: 7 additions & 2 deletions aggrec/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ class OtlpSettings(BaseModel):
class RedisSettings(BaseModel):
host: str = Field(description="Redis hostname")
port: int = Field(description="Redis port", default=6379)
ttl: int = Field(description="Redis cache TTL", default=300)


class KeyCacheSettings(BaseModel):
size: int = Field(description="Cache size", default=1000)
ttl: int = Field(description="Cache TTL", default=300)
redis: RedisSettings | None = None


class Settings(BaseSettings):
Expand All @@ -58,7 +63,7 @@ class Settings(BaseSettings):
mqtt: MqttSettings = Field(default=MqttSettings())
mongodb: MongoDB = Field(default=MongoDB())
otlp: OtlpSettings = Field(default=OtlpSettings())
redis: RedisSettings | None = None
key_cache: KeyCacheSettings | None = None

model_config = SettingsConfigDict(toml_file="aggrec.toml")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_key_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_redis_cache():
)

redis_client = fakeredis.FakeRedis()
key_cache = RedisKeyCache(redis_client=redis_client, default_ttl=60)
key_cache = RedisKeyCache(redis_client=redis_client, ttl=60)

res = key_cache.get(key_id)
assert res is None
Expand All @@ -31,7 +31,7 @@ def test_memory_cache():
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
)

key_cache = MemoryKeyCache()
key_cache = MemoryKeyCache(size=100, ttl=60)

res = key_cache.get(key_id)
assert res is None
Expand Down
3 changes: 1 addition & 2 deletions tests/test_key_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from cryptography.hazmat.primitives.asymmetric import ed25519
from pytest_httpx import HTTPXMock

from aggrec.key_cache import MemoryKeyCache
from aggrec.key_resolver import UrlKeyResolver


Expand All @@ -15,7 +14,7 @@ def test_url_key_resolver(httpx_mock: HTTPXMock):

httpx_mock.add_response(url=f"https://keys/{key_id}.pem", content=public_key_pem)

resolver = UrlKeyResolver(client_database_base_url="https://keys", key_cache=MemoryKeyCache())
resolver = UrlKeyResolver(client_database_base_url="https://keys")

res = resolver.resolve_public_key(key_id)
assert res == public_key

0 comments on commit f11eee3

Please sign in to comment.