Add LM Caching #31

Merged · 9 commits · Nov 12, 2024
83 changes: 83 additions & 0 deletions .github/tests/lm_tests.py
@@ -47,6 +47,7 @@ def print_usage_after_each_test(setup_models):
print(f"\nUsage stats for {model_name} after test:")
model.print_total_usage()
model.reset_stats()
model.reset_cache()


################################################################################
@@ -276,3 +277,85 @@ def test_custom_tokenizer():
tokens = custom_lm.count_tokens("Hello, world!")
assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
assert tokens < 100


################################################################################
# Cache tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=True)

# Check that "What is the capital of France?" becomes cached
first_batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]

first_responses = lm(first_batch).outputs
assert lm.stats.total_usage.cache_hits == 0

second_batch = [
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "What is the capital of Germany?"}],
]
second_responses = lm(second_batch).outputs
assert second_responses[0] == first_responses[1]
assert lm.stats.total_usage.cache_hits == 1

# Test clearing cache
lm.reset_cache()
lm.reset_stats()
lm(second_batch)
assert lm.stats.total_usage.cache_hits == 0


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_disable_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=False)

batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]
lm(batch)
assert lm.stats.total_usage.cache_hits == 0
lm(batch)
assert lm.stats.total_usage.cache_hits == 0

    # Now enable the cache. Note that the earlier calls (made with caching disabled) were not cached.
lotus.settings.configure(enable_cache=True)
first_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 0
second_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 2
assert first_responses == second_responses


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_reset_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=True)

batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]
lm(batch)
assert lm.stats.total_usage.cache_hits == 0
lm(batch)
assert lm.stats.total_usage.cache_hits == 2

lm.reset_cache(max_size=1)
lm(batch)
assert lm.stats.total_usage.cache_hits == 2
lm(batch)
assert lm.stats.total_usage.cache_hits == 3

lm.reset_cache(max_size=0)
lm(batch)
assert lm.stats.total_usage.cache_hits == 3
lm(batch)
assert lm.stats.total_usage.cache_hits == 3
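
The test_reset_cache arithmetic follows from LRU eviction in the new Cache: with max_size=1, the second prompt inserted during each call evicts the first, so after reset_cache(max_size=1) only the most recently inserted prompt can hit on the next call (hits stay at 2, then reach 3); with max_size=0, every insert is evicted immediately and hits stay at 3. A minimal standalone sketch of that eviction behavior, assuming the lotus package from this PR is importable (illustration only, not part of the diff):

import lotus
from lotus.cache import Cache

lotus.settings.configure(enable_cache=True)

cache = Cache(max_size=1)
cache.insert("hello", "resp_1")
cache.insert("capital", "resp_2")        # evicts "hello" (capacity 1, oldest first)
assert cache.get("hello") is None        # miss: evicted
assert cache.get("capital") == "resp_2"  # hit: most recently inserted entry

cache.reset(max_size=0)
cache.insert("capital", "resp_2")        # evicted immediately (capacity 0)
assert cache.get("capital") is None
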
27 changes: 27 additions & 0 deletions examples/model_examples/cache.py
@@ -0,0 +1,27 @@
import pandas as pd

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")

lotus.settings.configure(lm=lm, enable_cache=True)  # caching is disabled by default
data = {
"Course Name": [
"Probability and Random Processes",
"Optimization Methods in Engineering",
"Digital Design and Integrated Circuits",
"Computer Security",
]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print("====== intial run ======")
print(df)

# run a second time
df = df.sem_filter(user_instruction)
print("====== second run ======")
print(df)

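On the second run the same prompts and kwargs hash to the same keys, so the filter's LM calls are answered from the cache. A hedged follow-up to the example above, using the stats this PR adds (extra lines for illustration, not part of the diff):

lm.print_total_usage()                   # now also prints "Total cache hits: ..."
print(lm.stats.total_usage.cache_hits)   # expected to be > 0 after the second run

lm.reset_cache()   # clear cached responses
lm.reset_stats()   # zero the usage counters, cache_hits included
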
43 changes: 43 additions & 0 deletions lotus/cache.py
@@ -0,0 +1,43 @@
from collections import OrderedDict
from functools import wraps
from typing import Any, Callable

import lotus


def require_cache_enabled(func: Callable) -> Callable:
"""Decorator to check if caching is enabled before calling the function."""

@wraps(func)
def wrapper(self, *args, **kwargs):
if not lotus.settings.enable_cache:
return None
return func(self, *args, **kwargs)

return wrapper


class Cache:
def __init__(self, max_size: int):
self.max_size = max_size
self.cache: OrderedDict[str, Any] = OrderedDict()

@require_cache_enabled
def get(self, key: str) -> Any | None:
if key in self.cache:
lotus.logger.debug(f"Cache hit for {key}")

return self.cache.get(key)

@require_cache_enabled
def insert(self, key: str, value: Any):
self.cache[key] = value

# LRU eviction
if len(self.cache) > self.max_size:
self.cache.popitem(last=False)

def reset(self, max_size: int | None = None):
self.cache.clear()
if max_size is not None:
self.max_size = max_size
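
Because both get and insert are wrapped in require_cache_enabled, turning the setting off makes every lookup a miss and every insert a no-op, without clearing entries that are already stored; this is the behavior test_disable_cache relies on. A small sketch of that gating, assuming lotus from this PR is importable (illustration only, not part of the diff):

import lotus
from lotus.cache import Cache

cache = Cache(max_size=8)

lotus.settings.configure(enable_cache=False)
cache.insert("k", "v")         # no-op: the decorator returns None early
assert cache.get("k") is None  # always a miss while caching is disabled

lotus.settings.configure(enable_cache=True)
cache.insert("k", "v")
assert cache.get("k") == "v"   # normal lookup once caching is enabled
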
73 changes: 53 additions & 20 deletions lotus/models/lm.py
@@ -1,3 +1,4 @@
import hashlib
from typing import Any

import litellm
@@ -9,6 +10,7 @@
from tokenizers import Tokenizer

import lotus
from lotus.cache import Cache
from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade


@@ -21,6 +23,7 @@ def __init__(
max_tokens: int = 512,
max_batch_size: int = 64,
tokenizer: Tokenizer | None = None,
max_cache_size: int = 1024,
**kwargs: dict[str, Any],
):
self.model = model
@@ -31,40 +34,66 @@ def __init__(
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

self.stats: LMStats = LMStats()
self.cache = Cache(max_cache_size)

def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

# Set top_logprobs if logprobs requested
if all_kwargs.get("logprobs", False):
all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10)

all_responses: list[ModelResponse] = []
for i in range(0, len(messages), self.max_batch_size):
batch = messages[i : i + self.max_batch_size]
responses: list[ModelResponse] = batch_completion(
self.model,
batch,
drop_params=True,
**all_kwargs, # type: ignore
)
all_responses.extend(responses)

# throw errors, if any
for resp in all_responses:
if isinstance(resp, OpenAIError):
raise resp
all_kwargs.setdefault("top_logprobs", 10)

# Check cache and separate cached and uncached messages
hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
cached_responses = [self.cache.get(hash) for hash in hashed_messages]
uncached_data = [
(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None
]
self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs)

# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
self._cache_response(resp, hash)

# Merge all responses in original order and extract outputs
all_responses = self._merge_responses(cached_responses, uncached_responses)
outputs = [self._get_top_choice(resp) for resp in all_responses]
logprobs = (
[self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
)

for resp in all_responses:
self._update_stats(resp)

return LMOutput(outputs=outputs, logprobs=logprobs)

def _process_uncached_messages(self, uncached_data, all_kwargs):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
for i in range(0, len(uncached_data), self.max_batch_size):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
return uncached_responses

def _cache_response(self, response, hash):
"""Caches a response and updates stats if successful."""
if isinstance(response, OpenAIError):
raise response
self._update_stats(response)
self.cache.insert(hash, response)

def _hash_messages(self, messages: list[dict[str, str]], kwargs: dict[str, Any]) -> str:
"""Hash messages and kwargs to create a unique key for the cache"""
to_hash = str(self.model) + str(messages) + str(kwargs)
return hashlib.sha256(to_hash.encode()).hexdigest()

def _merge_responses(
self, cached_responses: list[ModelResponse | None], uncached_responses: list[ModelResponse]
) -> list[ModelResponse]:
"""Merge cached and uncached responses, maintaining order"""
uncached_iter = iter(uncached_responses)
return [resp if resp is not None else next(uncached_iter) for resp in cached_responses]

def _update_stats(self, response: ModelResponse):
if not hasattr(response, "usage"):
return
@@ -155,8 +184,12 @@ def print_total_usage(self):
print(f"Total prompt tokens: {self.stats.total_usage.prompt_tokens}")
print(f"Total completion tokens: {self.stats.total_usage.completion_tokens}")
print(f"Total tokens: {self.stats.total_usage.total_tokens}")
print(f"Total cache hits: {self.stats.total_usage.cache_hits}")

def reset_stats(self):
self.stats = LMStats(
total_usage=LMStats.TotalUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0, total_cost=0.0)
)

def reset_cache(self, max_size: int | None = None):
self.cache.reset(max_size)
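
In the rewritten __call__, each message list is hashed together with the model name and call kwargs, cached responses are looked up first, only the misses go through batch_completion, and _merge_responses stitches cached and fresh responses back together in the original request order. A minimal sketch of that order-preserving merge, with plain strings standing in for ModelResponse objects (illustration only, not part of the diff):

def merge(cached, fresh):
    # None marks a cache miss; each miss is filled from `fresh`, in order.
    fresh_iter = iter(fresh)
    return [c if c is not None else next(fresh_iter) for c in cached]

# Positions 0 and 2 were cache hits; positions 1 and 3 were recomputed.
assert merge(["A", None, "C", None], ["B", "D"]) == ["A", "B", "C", "D"]
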
2 changes: 1 addition & 1 deletion lotus/settings.py
@@ -115,4 +115,4 @@ def __repr__(self) -> str:

# set defaults
settings = Settings()
settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50)
settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50, enable_cache=False)
1 change: 1 addition & 0 deletions lotus/types.py
@@ -29,6 +29,7 @@ class TotalUsage(BaseModel):
completion_tokens: int = 0
total_tokens: int = 0
total_cost: float = 0.0
cache_hits: int = 0

total_usage: TotalUsage = TotalUsage()
