Skip to content

Commit

Permalink
add public key caching via REDIS
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Oct 11, 2024
1 parent 6e3ae4b commit 9e31626
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 146 deletions.
3 changes: 3 additions & 0 deletions aggrec.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ topic = "aggregates"
spans_endpoint = "http://localhost:4317"
metrics_endpoint = "http://localhost:4317"
insecure = true

[redis]
host = "localhost"
4 changes: 3 additions & 1 deletion aggrec/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ async def create_aggregate(
span = trace.get_current_span()

with tracer.start_as_current_span("http_request_verifier"):
http_request_verifier = RequestVerifier(client_database=request.app.settings.clients_database)
http_request_verifier = RequestVerifier(
client_database=request.app.settings.clients_database, key_cache=request.app.key_cache
)
res = await http_request_verifier.verify(request)

creator = res.parameters.get("keyid")
Expand Down
40 changes: 6 additions & 34 deletions aggrec/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import hashlib
import logging
from datetime import datetime, timezone
from urllib.parse import urljoin

import http_sf
import httpx
import pendulum
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from fastapi import HTTPException, Request, status
from http_message_signatures import (
HTTPMessageVerifier,
Expand All @@ -18,7 +15,9 @@
from http_message_signatures.algorithms import signature_algorithms as supported_signature_algorithms
from http_message_signatures.exceptions import InvalidSignature
from pydantic import AnyHttpUrl, DirectoryPath
from werkzeug.utils import safe_join

from .key_cache import KeyCache
from .key_resolver import FileKeyResolver, UrlKeyResolver

DEFAULT_SIGNATURE_ALGORITHM = algorithms.ECDSA_P256_SHA256
HASH_ALGORITHMS = {"sha-256": hashlib.sha256, "sha-512": hashlib.sha512}
Expand All @@ -40,50 +39,23 @@ class ContentDigestMissing(ContentDigestException):
pass


class FileKeyResolver(HTTPSignatureKeyResolver):
def __init__(self, client_database_directory: str):
self.client_database_directory = client_database_directory

def resolve_public_key(self, key_id: str):
filename = safe_join(self.client_database_directory, f"{key_id}.pem")
try:
with open(filename, "rb") as fp:
return load_pem_public_key(fp.read())
except FileNotFoundError as exc:
raise KeyError(key_id) from exc


class UrlKeyResolver(HTTPSignatureKeyResolver):
def __init__(self, client_database_base_url: str):
self.client_database_base_url = client_database_base_url
self.httpx_client = httpx.Client()

def resolve_public_key(self, key_id: str):
public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem")
try:
response = self.httpx_client.get(public_key_url)
response.raise_for_status()
return load_pem_public_key(response.content)
except httpx.HTTPError as exc:
raise KeyError(key_id) from exc


class RequestVerifier:
def __init__(
self,
algorithm: HTTPSignatureAlgorithm | None = None,
key_resolver: HTTPSignatureKeyResolver | None = None,
client_database: AnyHttpUrl | DirectoryPath | None = None,
key_cache: KeyCache | None = None,
):
self.algorithm = algorithm or DEFAULT_SIGNATURE_ALGORITHM
if key_resolver:
self.key_resolver = key_resolver
elif client_database and (
str(client_database).startswith("http://") or str(client_database).startswith("https://")
):
self.key_resolver = UrlKeyResolver(str(client_database))
self.key_resolver = UrlKeyResolver(client_database_base_url=str(client_database), key_cache=key_cache)
elif client_database:
self.key_resolver = FileKeyResolver(str(client_database))
self.key_resolver = FileKeyResolver(client_database_directory=str(client_database), key_cache=key_cache)
else:
raise ValueError("No key resolver nor client database specified")
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
Expand Down
42 changes: 42 additions & 0 deletions aggrec/key_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging
import time
from abc import abstractmethod

import redis

logger = logging.getLogger(__name__)


class KeyCache:
@abstractmethod
def get(self, key: str) -> bytes | None:
return None

@abstractmethod
def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
pass


class NoyKeyCache(KeyCache):
def get(self, key: str) -> bytes | None:
return None

def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
pass


class RedisKeyCache(KeyCache):
def __init__(self, redis_client: redis.Redis, default_ttl: int | None = None):
self.redis_client = redis_client
self.default_ttl = default_ttl

def get(self, key: str) -> bytes | None:
res = self.redis_client.get(name=key)
logger.debug("Cache GET %s (%s)", key, "hit" if res else "miss")
return res

def set(self, key: str, value: bytes, ttl: int | None = None) -> None:
ttl = ttl if ttl is not None else self.default_ttl
expires_at = int(time.time()) + ttl
logger.debug("Cache SET %s with TTL %d EXAT %d", key, ttl, expires_at)
self.redis_client.set(name=key, value=value, exat=expires_at)
55 changes: 55 additions & 0 deletions aggrec/key_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from abc import abstractmethod
from urllib.parse import urljoin

import httpx
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from http_message_signatures import HTTPSignatureKeyResolver
from werkzeug.utils import safe_join

from .key_cache import KeyCache


class CacheKeyResolver(HTTPSignatureKeyResolver):
def __init__(self, key_cache: KeyCache):
self.key_cache = key_cache

@abstractmethod
def get_public_key_pem(self, key_id: str) -> bytes:
pass

def resolve_public_key(self, key_id: str):
public_key_pem = self.key_cache.get(key_id)
if not public_key_pem:
public_key_pem = self.get_public_key_pem(key_id)
self.key_cache.set(key_id, public_key_pem)
return load_pem_public_key(public_key_pem)


class FileKeyResolver(CacheKeyResolver):
def __init__(self, client_database_directory: str, key_cache: KeyCache):
super().__init__(key_cache=key_cache)
self.client_database_directory = client_database_directory

def get_public_key_pem(self, key_id: str) -> bytes:
filename = safe_join(self.client_database_directory, f"{key_id}.pem")
try:
with open(filename, "rb") as fp:
return fp.read()
except FileNotFoundError as exc:
raise KeyError(key_id) from exc


class UrlKeyResolver(CacheKeyResolver):
def __init__(self, client_database_base_url: str, key_cache: KeyCache):
super().__init__(key_cache=key_cache)
self.client_database_base_url = client_database_base_url
self.httpx_client = httpx.Client()

def get_public_key_pem(self, key_id: str) -> bytes:
public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem")
try:
response = self.httpx_client.get(public_key_url)
response.raise_for_status()
return response.content
except httpx.HTTPError as exc:
raise KeyError(key_id) from exc
8 changes: 8 additions & 0 deletions aggrec/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import aiomqtt
import boto3
import mongoengine
import redis
import uvicorn
from fastapi import FastAPI
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
Expand All @@ -13,6 +14,7 @@
import aggrec.extras

from . import OPENAPI_METADATA, __verbose_version__
from .key_cache import NoyKeyCache, RedisKeyCache
from .logging import JsonFormatter # noqa
from .settings import Settings
from .telemetry import configure_opentelemetry
Expand Down Expand Up @@ -68,6 +70,12 @@ def __init__(self, settings: Settings):
metrics_endpoint=str(settings.otlp.metrics_endpoint),
insecure=settings.otlp.insecure,
)
if self.settings.redis:
redis_client = redis.StrictRedis(host=self.settings.redis.host, port=self.settings.redis.port)
self.logger.debug("Using REDIS at %s:%d", self.settings.redis.host, self.settings.redis.port)
self.key_cache = RedisKeyCache(redis_client=redis_client, default_ttl=self.settings.redis.ttl)
else:
self.key_cache = NoyKeyCache()

@staticmethod
def connect_mongodb(settings: Settings):
Expand Down
7 changes: 7 additions & 0 deletions aggrec/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ class OtlpSettings(BaseModel):
insecure: bool = False


class RedisSettings(BaseModel):
host: str = Field(description="Redis hostname")
port: int = Field(description="Redis port", default=6379)
ttl: int = Field(description="Redis cache TTL", default=300)


class Settings(BaseSettings):
metadata_base_url: AnyHttpUrl = Field(default="http://127.0.0.1")
clients_database: DirectoryPath | AnyHttpUrl = Field(default="clients")
s3: S3 = Field(default=S3())
mqtt: MqttSettings = Field(default=MqttSettings())
mongodb: MongoDB = Field(default=MongoDB())
otlp: OtlpSettings = Field(default=OtlpSettings())
redis: RedisSettings | None = None

model_config = SettingsConfigDict(toml_file="aggrec.toml")

Expand Down
Loading

0 comments on commit 9e31626

Please sign in to comment.