From 624c2c521d42241c4c66092ca8ea9993de73ab64 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 25 Jan 2025 21:18:03 -0800 Subject: [PATCH 1/4] check_same_thread is False and then used locking on the SQLite db --- lotus/cache.py | 108 +++++++++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 49 deletions(-) diff --git a/lotus/cache.py b/lotus/cache.py index 82c1c4cc..cabd695c 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -9,12 +9,14 @@ from enum import Enum from functools import wraps from typing import Any, Callable +import threading import pandas as pd import lotus + def require_cache_enabled(func: Callable) -> Callable: """Decorator to check if caching is enabled before calling the function.""" @@ -118,74 +120,82 @@ def __init__(self, max_size: int, cache_dir=os.path.expanduser("~/.lotus/cache") super().__init__(max_size) self.db_path = os.path.join(cache_dir, "lotus_cache.db") os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - self.conn = sqlite3.connect(self.db_path) + self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self._lock = threading.Lock() self._create_table() + def _create_table(self): - with self.conn: - self.conn.execute(""" - CREATE TABLE IF NOT EXISTS cache ( - key TEXT PRIMARY KEY, - value BLOB, - last_accessed INTEGER - ) - """) + with self._lock: + with self.conn: + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS cache ( + key TEXT PRIMARY KEY, + value BLOB, + last_accessed INTEGER + ) + """) def _get_time(self): return int(time.time()) @require_cache_enabled def get(self, key: str) -> Any | None: - with self.conn: - cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,)) - result = cursor.fetchone() - if result: - lotus.logger.debug(f"Cache hit for {key}") - value = pickle.loads(result[0]) - self.conn.execute( - "UPDATE cache SET last_accessed = ? WHERE key = ?", - ( - self._get_time(), - key, - ), - ) - return value - return None + with self._lock: + with self.conn: + cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,)) + result = cursor.fetchone() + if result: + lotus.logger.debug(f"Cache hit for {key}") + value = pickle.loads(result[0]) + self.conn.execute( + "UPDATE cache SET last_accessed = ? WHERE key = ?", + ( + self._get_time(), + key, + ), + ) + return value + cursor.close() + return None @require_cache_enabled def insert(self, key: str, value: Any): pickled_value = pickle.dumps(value) - with self.conn: - self.conn.execute( - """ - INSERT OR REPLACE INTO cache (key, value, last_accessed) - VALUES (?, ?, ?) - """, - (key, pickled_value, self._get_time()), - ) - self._enforce_size_limit() - - def _enforce_size_limit(self): - with self.conn: - count = self.conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0] - if count > self.max_size: - num_to_delete = count - self.max_size + with self._lock: + with self.conn: self.conn.execute( """ - DELETE FROM cache WHERE key IN ( - SELECT key FROM cache - ORDER BY last_accessed ASC - LIMIT ? - ) + INSERT OR REPLACE INTO cache (key, value, last_accessed) + VALUES (?, ?, ?) """, - (num_to_delete,), + (key, pickled_value, self._get_time()), ) + self._enforce_size_limit() + + def _enforce_size_limit(self): + with self._lock: + with self.conn: + count = self.conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0] + if count > self.max_size: + num_to_delete = count - self.max_size + self.conn.execute( + """ + DELETE FROM cache WHERE key IN ( + SELECT key FROM cache + ORDER BY last_accessed ASC + LIMIT ? + ) + """, + (num_to_delete,), + ) def reset(self, max_size: int | None = None): - with self.conn: - self.conn.execute("DELETE FROM cache") - if max_size is not None: - self.max_size = max_size + with self._lock: + with self.conn: + self.conn.execute("DELETE FROM cache") + if max_size is not None: + self.max_size = max_size def __del__(self): self.conn.close() From dd201d346bde279a750c2b6bb23e5a6a1ce2f5cd Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 25 Jan 2025 21:43:12 -0800 Subject: [PATCH 2/4] replaced naive approach with more robust one, creating a seperate connection per thread --- lotus/cache.py | 133 ++++++++++++++++++++++++++----------------------- 1 file changed, 71 insertions(+), 62 deletions(-) diff --git a/lotus/cache.py b/lotus/cache.py index cabd695c..45e0871c 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -114,92 +114,101 @@ def create_cache(config: CacheConfig) -> Cache: def create_default_cache(max_size: int = 1024) -> Cache: return CacheFactory.create_cache(CacheConfig(CacheType.IN_MEMORY, max_size)) +class ThreadLocalConnection: + """Wrapper that automatically closes connection when thread dies""" + def __init__(self, db_path: str): + self._db_path = db_path + self._conn = None + + @property + def connection(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect(self._db_path) + return self._conn + + def __del__(self): + if self._conn is not None: + self._conn.close() class SQLiteCache(Cache): def __init__(self, max_size: int, cache_dir=os.path.expanduser("~/.lotus/cache")): super().__init__(max_size) self.db_path = os.path.join(cache_dir, "lotus_cache.db") os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - self.conn = sqlite3.connect(self.db_path, check_same_thread=False) - self._lock = threading.Lock() + self._local = threading.local() self._create_table() - + + def _get_connection(self) -> sqlite3.Connection: + if not hasattr(self._local, 'conn_wrapper'): + self._local.conn_wrapper = ThreadLocalConnection(self.db_path) + return self._local.conn_wrapper.connection def _create_table(self): - with self._lock: - with self.conn: - self.conn.execute(""" - CREATE TABLE IF NOT EXISTS cache ( - key TEXT PRIMARY KEY, - value BLOB, - last_accessed INTEGER - ) - """) + with self._get_connection() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS cache ( + key TEXT PRIMARY KEY, + value BLOB, + last_accessed INTEGER + ) + """) def _get_time(self): return int(time.time()) @require_cache_enabled def get(self, key: str) -> Any | None: - with self._lock: - with self.conn: - cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,)) - result = cursor.fetchone() - if result: - lotus.logger.debug(f"Cache hit for {key}") - value = pickle.loads(result[0]) - self.conn.execute( - "UPDATE cache SET last_accessed = ? WHERE key = ?", - ( - self._get_time(), - key, - ), - ) - return value - cursor.close() - return None + with self._get_connection() as conn: + cursor = conn.execute("SELECT value FROM cache WHERE key = ?", (key,)) + result = cursor.fetchone() + if result: + lotus.logger.debug(f"Cache hit for {key}") + value = pickle.loads(result[0]) + conn.execute( + "UPDATE cache SET last_accessed = ? WHERE key = ?", + ( + self._get_time(), + key, + ), + ) + return value + cursor.close() + return None @require_cache_enabled def insert(self, key: str, value: Any): pickled_value = pickle.dumps(value) - with self._lock: - with self.conn: - self.conn.execute( - """ - INSERT OR REPLACE INTO cache (key, value, last_accessed) - VALUES (?, ?, ?) - """, - (key, pickled_value, self._get_time()), - ) - self._enforce_size_limit() + with self._get_connection() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO cache (key, value, last_accessed) + VALUES (?, ?, ?) + """, + (key, pickled_value, self._get_time()), + ) + self._enforce_size_limit() def _enforce_size_limit(self): - with self._lock: - with self.conn: - count = self.conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0] - if count > self.max_size: - num_to_delete = count - self.max_size - self.conn.execute( - """ - DELETE FROM cache WHERE key IN ( - SELECT key FROM cache - ORDER BY last_accessed ASC - LIMIT ? - ) - """, - (num_to_delete,), + with self._get_connection() as conn: + count = conn.execute("SELECT COUNT(*) FROM cache").fetchone()[0] + if count > self.max_size: + num_to_delete = count - self.max_size + conn.execute( + """ + DELETE FROM cache WHERE key IN ( + SELECT key FROM cache + ORDER BY last_accessed ASC + LIMIT ? ) + """, + (num_to_delete,), + ) def reset(self, max_size: int | None = None): - with self._lock: - with self.conn: - self.conn.execute("DELETE FROM cache") - if max_size is not None: - self.max_size = max_size - - def __del__(self): - self.conn.close() - + with self._get_connection() as conn: + conn.execute("DELETE FROM cache") + if max_size is not None: + self.max_size = max_size class InMemoryCache(Cache): def __init__(self, max_size: int): From a8a98094f0c0e7be2b469dd5813bfc4714d0f9b3 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 25 Jan 2025 21:48:00 -0800 Subject: [PATCH 3/4] ruff format --- lotus/cache.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/lotus/cache.py b/lotus/cache.py index 45e0871c..549757b8 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -3,20 +3,19 @@ import os import pickle import sqlite3 +import threading import time from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum from functools import wraps from typing import Any, Callable -import threading import pandas as pd import lotus - def require_cache_enabled(func: Callable) -> Callable: """Decorator to check if caching is enabled before calling the function.""" @@ -114,22 +113,25 @@ def create_cache(config: CacheConfig) -> Cache: def create_default_cache(max_size: int = 1024) -> Cache: return CacheFactory.create_cache(CacheConfig(CacheType.IN_MEMORY, max_size)) + class ThreadLocalConnection: """Wrapper that automatically closes connection when thread dies""" + def __init__(self, db_path: str): self._db_path = db_path self._conn = None - + @property def connection(self) -> sqlite3.Connection: if self._conn is None: self._conn = sqlite3.connect(self._db_path) return self._conn - + def __del__(self): if self._conn is not None: self._conn.close() + class SQLiteCache(Cache): def __init__(self, max_size: int, cache_dir=os.path.expanduser("~/.lotus/cache")): super().__init__(max_size) @@ -137,9 +139,9 @@ def __init__(self, max_size: int, cache_dir=os.path.expanduser("~/.lotus/cache") os.makedirs(os.path.dirname(self.db_path), exist_ok=True) self._local = threading.local() self._create_table() - + def _get_connection(self) -> sqlite3.Connection: - if not hasattr(self._local, 'conn_wrapper'): + if not hasattr(self._local, "conn_wrapper"): self._local.conn_wrapper = ThreadLocalConnection(self.db_path) return self._local.conn_wrapper.connection @@ -210,6 +212,7 @@ def reset(self, max_size: int | None = None): if max_size is not None: self.max_size = max_size + class InMemoryCache(Cache): def __init__(self, max_size: int): super().__init__(max_size) From ea25f76d6754178866af1d5a58ea955e0a5425c4 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 25 Jan 2025 21:55:51 -0800 Subject: [PATCH 4/4] fix mypy errors --- lotus/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/cache.py b/lotus/cache.py index 549757b8..540ef51d 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -119,7 +119,7 @@ class ThreadLocalConnection: def __init__(self, db_path: str): self._db_path = db_path - self._conn = None + self._conn: sqlite3.Connection | None = None @property def connection(self) -> sqlite3.Connection: