diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index ae68c109..e8e46f47 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -47,6 +47,7 @@ def print_usage_after_each_test(setup_models):
         print(f"\nUsage stats for {model_name} after test:")
         model.print_total_usage()
         model.reset_stats()
+        model.reset_cache()
 
 
 ################################################################################
@@ -276,3 +277,85 @@ def test_custom_tokenizer():
     tokens = custom_lm.count_tokens("Hello, world!")
     assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
     assert tokens < 100
+
+
+################################################################################
+# Cache tests
+################################################################################
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_cache(setup_models, model):
+    lm = setup_models[model]
+    lotus.settings.configure(lm=lm, enable_cache=True)
+
+    # Check that "What is the capital of France?" becomes cached
+    first_batch = [
+        [{"role": "user", "content": "Hello, world!"}],
+        [{"role": "user", "content": "What is the capital of France?"}],
+    ]
+
+    first_responses = lm(first_batch).outputs
+    assert lm.stats.total_usage.cache_hits == 0
+
+    second_batch = [
+        [{"role": "user", "content": "What is the capital of France?"}],
+        [{"role": "user", "content": "What is the capital of Germany?"}],
+    ]
+    second_responses = lm(second_batch).outputs
+    assert second_responses[0] == first_responses[1]
+    assert lm.stats.total_usage.cache_hits == 1
+
+    # Test clearing cache
+    lm.reset_cache()
+    lm.reset_stats()
+    lm(second_batch)
+    assert lm.stats.total_usage.cache_hits == 0
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_disable_cache(setup_models, model):
+    lm = setup_models[model]
+    lotus.settings.configure(lm=lm, enable_cache=False)
+
+    batch = [
+        [{"role": "user", "content": "Hello, world!"}],
+        [{"role": "user", "content": "What is the capital of France?"}],
+    ]
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 0
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 0
+
+    # Now enable cache. Note that the first batch is not cached.
+    lotus.settings.configure(enable_cache=True)
+    first_responses = lm(batch).outputs
+    assert lm.stats.total_usage.cache_hits == 0
+    second_responses = lm(batch).outputs
+    assert lm.stats.total_usage.cache_hits == 2
+    assert first_responses == second_responses
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_reset_cache(setup_models, model):
+    lm = setup_models[model]
+    lotus.settings.configure(lm=lm, enable_cache=True)
+
+    batch = [
+        [{"role": "user", "content": "Hello, world!"}],
+        [{"role": "user", "content": "What is the capital of France?"}],
+    ]
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 0
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 2
+
+    lm.reset_cache(max_size=1)
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 2
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 3
+
+    lm.reset_cache(max_size=0)
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 3
+    lm(batch)
+    assert lm.stats.total_usage.cache_hits == 3
diff --git a/examples/model_examples/cache.py b/examples/model_examples/cache.py
new file mode 100644
index 00000000..556ed84d
--- /dev/null
+++ b/examples/model_examples/cache.py
@@ -0,0 +1,27 @@
+import pandas as pd
+
+import lotus
+from lotus.models import LM
+
+lm = LM(model="gpt-4o-mini")
+
+lotus.settings.configure(lm=lm, enable_cache=True)  # default caching is False
+data = {
+    "Course Name": [
+        "Probability and Random Processes",
+        "Optimization Methods in Engineering",
+        "Digital Design and Integrated Circuits",
+        "Computer Security",
+    ]
+}
+df = pd.DataFrame(data)
+user_instruction = "{Course Name} requires a lot of math"
+df = df.sem_filter(user_instruction)
+print("====== initial run ======")
+print(df)
+
+# run a second time
+df = df.sem_filter(user_instruction)
+print("====== second run ======")
+print(df)
+
diff --git a/lotus/cache.py b/lotus/cache.py
new file mode 100644
index 00000000..33004013
--- /dev/null
+++ b/lotus/cache.py
@@ -0,0 +1,43 @@
+from collections import OrderedDict
+from functools import wraps
+from typing import Any, Callable
+
+import lotus
+
+
+def require_cache_enabled(func: Callable) -> Callable:
+    """Decorator to check if caching is enabled before calling the function."""
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if not lotus.settings.enable_cache:
+            return None
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class Cache:
+    def __init__(self, max_size: int):
+        self.max_size = max_size
+        self.cache: OrderedDict[str, Any] = OrderedDict()
+
+    @require_cache_enabled
+    def get(self, key: str) -> Any | None:
+        if key in self.cache:
+            lotus.logger.debug(f"Cache hit for {key}")
+
+        return self.cache.get(key)
+
+    @require_cache_enabled
+    def insert(self, key: str, value: Any):
+        self.cache[key] = value
+
+        # LRU eviction
+        if len(self.cache) > self.max_size:
+            self.cache.popitem(last=False)
+
+    def reset(self, max_size: int | None = None):
+        self.cache.clear()
+        if max_size is not None:
+            self.max_size = max_size
diff --git a/lotus/models/lm.py b/lotus/models/lm.py
index ee03a444..30852ebb 100644
--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -1,3 +1,4 @@
+import hashlib
 from typing import Any
 
 import litellm
@@ -9,6 +10,7 @@ from tokenizers import Tokenizer
 
 import lotus
+from lotus.cache import Cache
 from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade
 
 
@@ -21,6 +23,7 @@ def __init__(
         max_tokens: int = 512,
         max_batch_size: int = 64,
         tokenizer: Tokenizer | None = None,
+        max_cache_size: int = 1024,
         **kwargs: dict[str, Any],
     ):
         self.model = model
@@ -31,40 +34,66 @@ def __init__(
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
         self.stats: LMStats = LMStats()
+        self.cache = Cache(max_cache_size)
 
     def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
         all_kwargs = {**self.kwargs, **kwargs}
 
         # Set top_logprobs if logprobs requested
         if all_kwargs.get("logprobs", False):
-            all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10)
-
-        all_responses: list[ModelResponse] = []
-        for i in range(0, len(messages), self.max_batch_size):
-            batch = messages[i : i + self.max_batch_size]
-            responses: list[ModelResponse] = batch_completion(
-                self.model,
-                batch,
-                drop_params=True,
-                **all_kwargs,  # type: ignore
-            )
-            all_responses.extend(responses)
-
-        # throw errors, if any
-        for resp in all_responses:
-            if isinstance(resp, OpenAIError):
-                raise resp
+            all_kwargs.setdefault("top_logprobs", 10)
 
+        # Check cache and separate cached and uncached messages
+        hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
+        cached_responses = [self.cache.get(hash) for hash in hashed_messages]
+        uncached_data = [
+            (msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None
+        ]
+        self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)
+
+        # Process uncached messages in batches
+        uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs)
+
+        # Add new responses to cache
+        for resp, (_, hash) in zip(uncached_responses, uncached_data):
+            self._cache_response(resp, hash)
+
+        # Merge all responses in original order and extract outputs
+        all_responses = self._merge_responses(cached_responses, uncached_responses)
         outputs = [self._get_top_choice(resp) for resp in all_responses]
         logprobs = (
             [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
         )
 
-        for resp in all_responses:
-            self._update_stats(resp)
-
         return LMOutput(outputs=outputs, logprobs=logprobs)
 
+    def _process_uncached_messages(self, uncached_data, all_kwargs):
+        """Processes uncached messages in batches and returns responses."""
+        uncached_responses = []
+        for i in range(0, len(uncached_data), self.max_batch_size):
+            batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
+            uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
+        return uncached_responses
+
+    def _cache_response(self, response, hash):
+        """Caches a response and updates stats if successful."""
+        if isinstance(response, OpenAIError):
+            raise response
+        self._update_stats(response)
+        self.cache.insert(hash, response)
+
+    def _hash_messages(self, messages: list[dict[str, str]], kwargs: dict[str, Any]) -> str:
+        """Hash messages and kwargs to create a unique key for the cache"""
+        to_hash = str(self.model) + str(messages) + str(kwargs)
+        return hashlib.sha256(to_hash.encode()).hexdigest()
+
+    def _merge_responses(
+        self, cached_responses: list[ModelResponse | None], uncached_responses: list[ModelResponse]
+    ) -> list[ModelResponse]:
+        """Merge cached and uncached responses, maintaining order"""
+        uncached_iter = iter(uncached_responses)
+        return [resp if resp is not None else next(uncached_iter) for resp in cached_responses]
+
     def _update_stats(self, response: ModelResponse):
         if not hasattr(response, "usage"):
             return
@@ -155,8 +184,12 @@ def print_total_usage(self):
         print(f"Total prompt tokens: {self.stats.total_usage.prompt_tokens}")
print(f"Total completion tokens: {self.stats.total_usage.completion_tokens}") print(f"Total tokens: {self.stats.total_usage.total_tokens}") + print(f"Total cache hits: {self.stats.total_usage.cache_hits}") def reset_stats(self): self.stats = LMStats( total_usage=LMStats.TotalUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0, total_cost=0.0) ) + + def reset_cache(self, max_size: int | None = None): + self.cache.reset(max_size) diff --git a/lotus/settings.py b/lotus/settings.py index a928880c..27aefdac 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -115,4 +115,4 @@ def __repr__(self) -> str: # set defaults settings = Settings() -settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) +settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50, enable_cache=False) diff --git a/lotus/types.py b/lotus/types.py index 28cbcfe9..1d7a3bcc 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -29,6 +29,7 @@ class TotalUsage(BaseModel): completion_tokens: int = 0 total_tokens: int = 0 total_cost: float = 0.0 + cache_hits: int = 0 total_usage: TotalUsage = TotalUsage()