Cache multithreading bugfix #92

Merged · 4 commits · Feb 6, 2025

lotus/cache.py: 54 changes (38 additions & 16 deletions)
@@ -3,6 +3,7 @@
 import os
 import pickle
 import sqlite3
+import threading
 import time
 from abc import ABC, abstractmethod
 from collections import OrderedDict
@@ -113,17 +114,40 @@ def create_default_cache(max_size: int = 1024) -> Cache:
     return CacheFactory.create_cache(CacheConfig(CacheType.IN_MEMORY, max_size))
 
 
+class ThreadLocalConnection:
+    """Wrapper that automatically closes connection when thread dies"""
+
+    def __init__(self, db_path: str):
+        self._db_path = db_path
+        self._conn: sqlite3.Connection | None = None
+
+    @property
+    def connection(self) -> sqlite3.Connection:
+        if self._conn is None:
+            self._conn = sqlite3.connect(self._db_path)
+        return self._conn
+
+    def __del__(self):
+        if self._conn is not None:
+            self._conn.close()
+
+
 class SQLiteCache(Cache):
     def __init__(self, max_size: int, cache_dir=os.path.expanduser("~/.lotus/cache")):
         super().__init__(max_size)
         self.db_path = os.path.join(cache_dir, "lotus_cache.db")
         os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
-        self.conn = sqlite3.connect(self.db_path)
+        self._local = threading.local()
         self._create_table()
 
+    def _get_connection(self) -> sqlite3.Connection:
+        if not hasattr(self._local, "conn_wrapper"):
+            self._local.conn_wrapper = ThreadLocalConnection(self.db_path)
+        return self._local.conn_wrapper.connection
+
     def _create_table(self):
-        with self.conn:
-            self.conn.execute("""
+        with self._get_connection() as conn:
+            conn.execute("""
                 CREATE TABLE IF NOT EXISTS cache (
                     key TEXT PRIMARY KEY,
                     value BLOB,
@@ -136,27 +160,28 @@ def _get_time(self):
 
     @require_cache_enabled
     def get(self, key: str) -> Any | None:
-        with self.conn:
-            cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,))
+        with self._get_connection() as conn:
+            cursor = conn.execute("SELECT value FROM cache WHERE key = ?", (key,))
             result = cursor.fetchone()
             if result:
                 lotus.logger.debug(f"Cache hit for {key}")
                 value = pickle.loads(result[0])
-                self.conn.execute(
+                conn.execute(
                     "UPDATE cache SET last_accessed = ? WHERE key = ?",
                     (
                         self._get_time(),
                         key,
                     ),
                 )
                 return value
+            cursor.close()
             return None
 
     @require_cache_enabled
     def insert(self, key: str, value: Any):
         pickled_value = pickle.dumps(value)
-        with self.conn:
-            self.conn.execute(
+        with self._get_connection() as conn:
+            conn.execute(
                 """
                 INSERT OR REPLACE INTO cache (key, value, last_accessed)
                 VALUES (?, ?, ?)
@@ -166,11 +191,11 @@ def insert(self, key: str, value: Any):
         self._enforce_size_limit()
 
     def _enforce_size_limit(self):
-        with self.conn:
-            count = self.conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0]
+        with self._get_connection() as conn:
+            count = conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0]
             if count > self.max_size:
                 num_to_delete = count - self.max_size
-                self.conn.execute(
+                conn.execute(
                     """
                     DELETE FROM cache WHERE key IN (
                         SELECT key FROM cache
@@ -182,14 +207,11 @@ def _enforce_size_limit(self):
                 )
 
     def reset(self, max_size: int | None = None):
-        with self.conn:
-            self.conn.execute("DELETE FROM cache")
+        with self._get_connection() as conn:
+            conn.execute("DELETE FROM cache")
         if max_size is not None:
             self.max_size = max_size
 
-    def __del__(self):
-        self.conn.close()
-
 
 class InMemoryCache(Cache):
     def __init__(self, max_size: int):
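
Not part of the PR, but as a reviewer aid: a minimal sketch of how the thread-local connection pattern above behaves when one SQLite file is hit from several threads. It assumes the lotus package from this branch is installed (for `ThreadLocalConnection`); the temporary database path, `get_connection` helper, and `worker` function are illustrative only, not lotus APIs.

```python
import os
import sqlite3
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor

from lotus.cache import ThreadLocalConnection  # added in this PR

db_path = os.path.join(tempfile.mkdtemp(), "demo.db")  # throwaway database for the sketch
local = threading.local()


def get_connection() -> sqlite3.Connection:
    # Mirrors SQLiteCache._get_connection: one wrapper, and therefore one
    # sqlite3.Connection, per thread, created lazily on first use.
    if not hasattr(local, "conn_wrapper"):
        local.conn_wrapper = ThreadLocalConnection(db_path)
    return local.conn_wrapper.connection


# Create the table once from the main thread.
with get_connection() as conn:
    conn.execute("CREATE TABLE IF NOT EXISTS demo (k INTEGER PRIMARY KEY, v TEXT)")


def worker(i: int) -> None:
    # Reusing one connection across threads would raise sqlite3.ProgrammingError
    # ("SQLite objects created in a thread can only be used in that same thread");
    # the per-thread connection avoids that.
    with get_connection() as conn:
        conn.execute(
            "INSERT OR REPLACE INTO demo (k, v) VALUES (?, ?)",
            (i, str(threading.get_ident())),
        )


with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(worker, range(32)))

print(get_connection().execute("SELECT COUNT(*) FROM demo").fetchone()[0])  # expect 32
```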