From 8a207aa729ae221f84916b0edde93cc313ecd2a2 Mon Sep 17 00:00:00 2001 From: Stanley Chan <149976039+StanChan03@users.noreply.github.com> Date: Sun, 12 Jan 2025 19:01:32 -0800 Subject: [PATCH] Independent operator level cache + doc update (#74) Independent operator level caching + doc updates --- .github/tests/lm_tests.py | 17 ++++++++++------- docs/configurations.rst | 11 ++++++++++- lotus/cache.py | 38 ++++++++++++++++++++++++++------------ lotus/models/lm.py | 31 ++++++++++++++++++++----------- lotus/settings.py | 3 +-- 5 files changed, 67 insertions(+), 33 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 3a18bf8..1704bbd 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -399,7 +399,7 @@ def test_custom_tokenizer(): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_cache(setup_models, model): lm = setup_models[model] - lotus.settings.configure(lm=lm, enable_message_cache=True) + lotus.settings.configure(lm=lm, enable_cache=True) # Check that "What is the capital of France?" becomes cached first_batch = [ @@ -428,19 +428,20 @@ def test_cache(setup_models, model): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_disable_cache(setup_models, model): lm = setup_models[model] - lotus.settings.configure(lm=lm, enable_message_cache=False) + lotus.settings.configure(lm=lm, enable_cache=False) batch = [ [{"role": "user", "content": "Hello, world!"}], [{"role": "user", "content": "What is the capital of France?"}], ] + lm(batch) assert lm.stats.total_usage.cache_hits == 0 lm(batch) assert lm.stats.total_usage.cache_hits == 0 # Now enable cache. Note that the first batch is not cached. - lotus.settings.configure(enable_message_cache=True) + lotus.settings.configure(enable_cache=True) first_responses = lm(batch).outputs assert lm.stats.total_usage.cache_hits == 0 second_responses = lm(batch).outputs @@ -451,7 +452,7 @@ def test_disable_cache(setup_models, model): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_reset_cache(setup_models, model): lm = setup_models[model] - lotus.settings.configure(lm=lm, enable_message_cache=True) + lotus.settings.configure(lm=lm, enable_cache=True) batch = [ [{"role": "user", "content": "Hello, world!"}], @@ -481,7 +482,7 @@ def test_operator_cache(setup_models, model): cache = CacheFactory.create_cache(cache_config) lm = LM(model="gpt-4o-mini", cache=cache) - lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True) + lotus.settings.configure(lm=lm, enable_cache=True) data = { "Course Name": [ @@ -537,7 +538,7 @@ def test_disable_operator_cache(setup_models, model): cache = CacheFactory.create_cache(cache_config) lm = LM(model="gpt-4o-mini", cache=cache) - lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False) + lotus.settings.configure(lm=lm, enable_cache=False) data = { "Course Name": [ @@ -569,15 +570,17 @@ def test_disable_operator_cache(setup_models, model): user_instruction = "What is a similar course to {Course Name}. Please just output the course name." 
first_response = df.sem_map(user_instruction) + first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower() assert lm.stats.total_usage.operator_cache_hits == 0 second_response = df.sem_map(user_instruction) + second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower() assert lm.stats.total_usage.operator_cache_hits == 0 pd.testing.assert_frame_equal(first_response, second_response) # Now enable operator cache. - lotus.settings.configure(enable_operator_cache=True) + lotus.settings.configure(enable_cache=True) first_responses = df.sem_map(user_instruction) first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower() assert lm.stats.total_usage.operator_cache_hits == 0 diff --git a/docs/configurations.rst b/docs/configurations.rst index 0ee0172..16cec15 100644 --- a/docs/configurations.rst +++ b/docs/configurations.rst @@ -24,9 +24,18 @@ Configurable Parameters 1. enable_message_cache: * Description: Enables or Disables cahcing mechanisms * Default: False + * Parameters: + - cache_type: Type of caching (SQLITE or In_MEMORY) + - max_size: maximum size of cache + - cache_dir: Directory for where DB file is stored. Default: "~/.lotus/cache" + * Note: It is recommended to enable caching .. code-block:: python - lotus.settings.configure(enable_message_cache=True) + cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000) + cache = CacheFactory.create_cache(cache_config) + + lm = LM(model='gpt-4o-mini', cache=cache) + lotus.settings.configure(lm=lm, enable_cache=True) 2. setting RM: * Description: Configures the retrieval model diff --git a/lotus/cache.py b/lotus/cache.py index 82c1c4c..b5b85c7 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -20,7 +20,7 @@ def require_cache_enabled(func: Callable) -> Callable: @wraps(func) def wrapper(self, *args, **kwargs): - if not lotus.settings.enable_message_cache: + if not lotus.settings.enable_cache: return None return func(self, *args, **kwargs) @@ -33,21 +33,39 @@ def operator_cache(func: Callable) -> Callable: @wraps(func) def wrapper(self, *args, **kwargs): model = lotus.settings.lm - use_operator_cache = lotus.settings.enable_operator_cache + use_operator_cache = lotus.settings.enable_cache if use_operator_cache and model.cache: - def serialize(value): - if isinstance(value, pd.DataFrame): - return value.to_json() + def serialize(value: Any) -> Any: + """ + Serialize a value into a JSON-serializable format. + Supports basic types, pandas DataFrames, and objects with a `dict` or `__dict__` method. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + elif isinstance(value, pd.DataFrame): + return value.to_json(orient="split") + elif isinstance(value, (list, tuple)): + return [serialize(item) for item in value] + elif isinstance(value, dict): + return {key: serialize(val) for key, val in value.items()} elif hasattr(value, "dict"): return value.dict() - return value - + elif hasattr(value, "__dict__"): + return {key: serialize(val) for key, val in vars(value).items() if not key.startswith("_")} + else: + # For unsupported types, convert to string (last resort) + lotus.logger.warning(f"Unsupported type {type(value)} for serialization. 
Converting to string.") + return str(value) + + serialize_self = serialize(self) serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()} serialized_args = [serialize(arg) for arg in args] cache_key = hashlib.sha256( - json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode() + json.dumps( + {"self": serialize_self, "args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True + ).encode() ).hexdigest() cached_result = model.cache.get(cache_key) @@ -134,7 +152,6 @@ def _create_table(self): def _get_time(self): return int(time.time()) - @require_cache_enabled def get(self, key: str) -> Any | None: with self.conn: cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,)) @@ -152,7 +169,6 @@ def get(self, key: str) -> Any | None: return value return None - @require_cache_enabled def insert(self, key: str, value: Any): pickled_value = pickle.dumps(value) with self.conn: @@ -196,14 +212,12 @@ def __init__(self, max_size: int): super().__init__(max_size) self.cache: OrderedDict[str, Any] = OrderedDict() - @require_cache_enabled def get(self, key: str) -> Any | None: if key in self.cache: lotus.logger.debug(f"Cache hit for {key}") return self.cache.get(key) - @require_cache_enabled def insert(self, key: str, value: Any): self.cache[key] = value diff --git a/lotus/models/lm.py b/lotus/models/lm.py index c63868b..8ebccf3 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -55,25 +55,34 @@ def __call__( if all_kwargs.get("logprobs", False): all_kwargs.setdefault("top_logprobs", 10) - # Check cache and separate cached and uncached messages - hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages] - cached_responses = [self.cache.get(hash) for hash in hashed_messages] - uncached_data = [ - (msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None - ] + if lotus.settings.enable_cache: + # Check cache and separate cached and uncached messages + hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages] + cached_responses = [self.cache.get(hash) for hash in hashed_messages] + + uncached_data = ( + [(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None] + if lotus.settings.enable_cache + else [(msg, "no-cache") for msg in messages] + ) + self.stats.total_usage.cache_hits += len(messages) - len(uncached_data) # Process uncached messages in batches uncached_responses = self._process_uncached_messages( uncached_data, all_kwargs, show_progress_bar, progress_bar_desc ) - - # Add new responses to cache - for resp, (_, hash) in zip(uncached_responses, uncached_data): - self._cache_response(resp, hash) + if lotus.settings.enable_cache: + # Add new responses to cache + for resp, (_, hash) in zip(uncached_responses, uncached_data): + self._cache_response(resp, hash) # Merge all responses in original order and extract outputs - all_responses = self._merge_responses(cached_responses, uncached_responses) + all_responses = ( + self._merge_responses(cached_responses, uncached_responses) + if lotus.settings.enable_cache + else uncached_responses + ) outputs = [self._get_top_choice(resp) for resp in all_responses] logprobs = ( [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None diff --git a/lotus/settings.py b/lotus/settings.py index 99e5944..ce12363 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -12,8 +12,7 @@ class Settings: reranker: 
lotus.models.Reranker | None = None # Cache settings - enable_message_cache: bool = False - enable_operator_cache: bool = False + enable_cache: bool = False # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT
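
Usage sketch for the unified flag: with this patch, the single `enable_cache` setting gates both the message-level cache in lotus/models/lm.py and the operator-level cache applied through the `operator_cache` decorator in lotus/cache.py. The example below mirrors the updated docs/configurations.rst snippet and the tests above; the import locations of CacheConfig, CacheFactory, CacheType, and LM are assumed (they are not shown in the hunks) and may need adjusting.

    import pandas as pd

    import lotus
    from lotus.cache import CacheConfig, CacheFactory, CacheType  # assumed import path
    from lotus.models import LM                                   # assumed import path

    # SQLite-backed cache, as in the updated docs example.
    cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
    cache = CacheFactory.create_cache(cache_config)

    lm = LM(model="gpt-4o-mini", cache=cache)
    lotus.settings.configure(lm=lm, enable_cache=True)

    df = pd.DataFrame({"Course Name": ["Probability and Random Processes"]})
    instruction = "What is a similar course to {Course Name}. Please just output the course name."

    # The first call populates the caches; an identical second call should be served
    # from the operator cache and increment lm.stats.total_usage.operator_cache_hits.
    df.sem_map(instruction)
    df.sem_map(instruction)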
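
For reference, the operator cache key introduced in lotus/cache.py is a SHA-256 digest of the JSON dump of the serialized operator inputs (self, args, kwargs). A standalone sketch of that derivation, simplified to the DataFrame and basic-type cases handled by the new serialize() helper:

    import hashlib
    import json

    import pandas as pd

    def operator_cache_key(self_obj, args, kwargs):
        def serialize(value):
            # Simplified stand-in for serialize() in lotus/cache.py: DataFrames are
            # dumped with to_json(orient="split"); basic JSON types pass through.
            if isinstance(value, pd.DataFrame):
                return value.to_json(orient="split")
            return value

        payload = {
            "self": serialize(self_obj),
            "args": [serialize(a) for a in args],
            "kwargs": {k: serialize(v) for k, v in kwargs.items()},
        }
        # sort_keys=True keeps the key stable across dict insertion order.
        return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()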