Independent operator level cache + doc update #74

Merged · 13 commits · Jan 13, 2025
17 changes: 10 additions & 7 deletions .github/tests/lm_tests.py
@@ -399,7 +399,7 @@ def test_custom_tokenizer():
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_message_cache=True)
lotus.settings.configure(lm=lm, enable_cache=True)

# Check that "What is the capital of France?" becomes cached
first_batch = [
@@ -428,19 +428,20 @@ def test_cache(setup_models, model):
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_disable_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_message_cache=False)
lotus.settings.configure(lm=lm, enable_cache=False)

batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]

lm(batch)
assert lm.stats.total_usage.cache_hits == 0
lm(batch)
assert lm.stats.total_usage.cache_hits == 0

# Now enable cache. Note that the first batch is not cached.
lotus.settings.configure(enable_message_cache=True)
lotus.settings.configure(enable_cache=True)
first_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 0
second_responses = lm(batch).outputs
@@ -451,7 +452,7 @@ def test_disable_cache(setup_models, model):
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_reset_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_message_cache=True)
lotus.settings.configure(lm=lm, enable_cache=True)

batch = [
[{"role": "user", "content": "Hello, world!"}],
@@ -481,7 +482,7 @@ def test_operator_cache(setup_models, model):
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)
lotus.settings.configure(lm=lm, enable_cache=True)

data = {
"Course Name": [
@@ -537,7 +538,7 @@ def test_disable_operator_cache(setup_models, model):
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False)
lotus.settings.configure(lm=lm, enable_cache=False)

data = {
"Course Name": [
@@ -569,15 +570,17 @@ def test_disable_operator_cache(setup_models, model):
user_instruction = "What is a similar course to {Course Name}. Please just output the course name."

first_response = df.sem_map(user_instruction)
first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0

second_response = df.sem_map(user_instruction)
second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0

pd.testing.assert_frame_equal(first_response, second_response)

# Now enable operator cache.
lotus.settings.configure(enable_operator_cache=True)
lotus.settings.configure(enable_cache=True)
first_responses = df.sem_map(user_instruction)
first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0
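The updated tests normalize the `_map` column before comparing the two sem_map runs, since model outputs can differ in punctuation or casing between calls even when the rows are semantically the same. A minimal sketch of that comparison (the sample strings are invented; only the regex and lowercasing mirror the test code above):

import pandas as pd

def normalize(col: pd.Series) -> pd.Series:
    # Keep only letters and whitespace, then lowercase, so cosmetic
    # differences between two generations do not fail the frame comparison.
    return col.str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()

# Invented sample outputs standing in for two sem_map runs.
first = pd.DataFrame({"_map": ["Operating Systems.", "Databases!"]})
second = pd.DataFrame({"_map": ["operating systems", "databases"]})

pd.testing.assert_frame_equal(
    first.assign(_map=normalize(first["_map"])),
    second.assign(_map=normalize(second["_map"])),
)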
11 changes: 10 additions & 1 deletion docs/configurations.rst
@@ -24,9 +24,18 @@ Configurable Parameters
1. enable_message_cache:
* Description: Enables or disables caching mechanisms
* Default: False
* Parameters:
- cache_type: Type of caching (SQLITE or In_MEMORY)
- max_size: maximum size of cache
- cache_dir: Directory for where DB file is stored. Default: "~/.lotus/cache"
* Note: It is recommended to enable caching
.. code-block:: python

lotus.settings.configure(enable_message_cache=True)
cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
cache = CacheFactory.create_cache(cache_config)

lm = LM(model='gpt-4o-mini', cache=cache)
lotus.settings.configure(lm=lm, enable_cache=True)

2. setting RM:
* Description: Configures the retrieval model
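Putting the documented configuration together with the behavior exercised in the tests, a short end-to-end sketch follows. The import paths are assumptions based on the files touched in this PR (lotus/cache.py and lotus/models/lm.py), and the model call requires valid API credentials:

import lotus
from lotus.cache import CacheConfig, CacheFactory, CacheType  # assumed import path
from lotus.models import LM                                   # assumed import path

# SQLite-backed cache as in the docs above, plus the unified enable_cache flag.
cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_cache=True)

# The second identical batch should be answered from the message cache,
# mirroring test_cache above.
batch = [[{"role": "user", "content": "What is the capital of France?"}]]
lm(batch)
lm(batch)
print(lm.stats.total_usage.cache_hits)  # expected: 1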
38 changes: 26 additions & 12 deletions lotus/cache.py
@@ -20,7 +20,7 @@ def require_cache_enabled(func: Callable) -> Callable:

@wraps(func)
def wrapper(self, *args, **kwargs):
if not lotus.settings.enable_message_cache:
if not lotus.settings.enable_cache:
return None
return func(self, *args, **kwargs)

@@ -33,21 +33,39 @@ def operator_cache(func: Callable) -> Callable:
@wraps(func)
def wrapper(self, *args, **kwargs):
model = lotus.settings.lm
use_operator_cache = lotus.settings.enable_operator_cache
use_operator_cache = lotus.settings.enable_cache

if use_operator_cache and model.cache:

def serialize(value):
if isinstance(value, pd.DataFrame):
return value.to_json()
def serialize(value: Any) -> Any:
"""
Serialize a value into a JSON-serializable format.
Supports basic types, pandas DataFrames, and objects with a `dict` or `__dict__` method.
"""
if value is None or isinstance(value, (str, int, float, bool)):
return value
elif isinstance(value, pd.DataFrame):
return value.to_json(orient="split")
elif isinstance(value, (list, tuple)):
return [serialize(item) for item in value]
elif isinstance(value, dict):
return {key: serialize(val) for key, val in value.items()}
elif hasattr(value, "dict"):
return value.dict()
return value

elif hasattr(value, "__dict__"):
return {key: serialize(val) for key, val in vars(value).items() if not key.startswith("_")}
else:
# For unsupported types, convert to string (last resort)
lotus.logger.warning(f"Unsupported type {type(value)} for serialization. Converting to string.")
return str(value)

serialize_self = serialize(self)
serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
serialized_args = [serialize(arg) for arg in args]
cache_key = hashlib.sha256(
json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
json.dumps(
{"self": serialize_self, "args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True
).encode()
).hexdigest()

cached_result = model.cache.get(cache_key)
@@ -134,7 +152,6 @@ def _create_table(self):
def _get_time(self):
return int(time.time())

@require_cache_enabled
def get(self, key: str) -> Any | None:
with self.conn:
cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,))
@@ -152,7 +169,6 @@ def get(self, key: str) -> Any | None:
return value
return None

@require_cache_enabled
def insert(self, key: str, value: Any):
pickled_value = pickle.dumps(value)
with self.conn:
@@ -196,14 +212,12 @@ def __init__(self, max_size: int):
super().__init__(max_size)
self.cache: OrderedDict[str, Any] = OrderedDict()

@require_cache_enabled
def get(self, key: str) -> Any | None:
if key in self.cache:
lotus.logger.debug(f"Cache hit for {key}")

return self.cache.get(key)

@require_cache_enabled
def insert(self, key: str, value: Any):
self.cache[key] = value

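For reference, the operator-level key now covers the serialized operator object (`self`) in addition to the call arguments, so two identical calls over the same DataFrame hash to the same key. A trimmed, standalone sketch of that derivation (the real serialize() above also handles objects exposing dict/__dict__ and logs a warning for unsupported types; operator_cache_key is an illustrative helper, not a lotus API):

import hashlib
import json
import pandas as pd

def serialize(value):
    # Simplified version of the serializer in operator_cache above.
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, pd.DataFrame):
        return value.to_json(orient="split")
    if isinstance(value, (list, tuple)):
        return [serialize(v) for v in value]
    if isinstance(value, dict):
        return {k: serialize(v) for k, v in value.items()}
    return str(value)  # last-resort fallback, as in the PR

def operator_cache_key(self_obj, args, kwargs) -> str:
    payload = {
        "self": serialize(self_obj),
        "args": [serialize(a) for a in args],
        "kwargs": {k: serialize(v) for k, v in kwargs.items()},
    }
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()

# Identical inputs hash to the same key, so a repeated call can be served
# from the operator cache.
df = pd.DataFrame({"Course Name": ["Operating Systems"]})
key1 = operator_cache_key(df, ("What is a similar course to {Course Name}",), {})
key2 = operator_cache_key(df, ("What is a similar course to {Course Name}",), {})
assert key1 == key2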
31 changes: 20 additions & 11 deletions lotus/models/lm.py
@@ -55,25 +55,34 @@ def __call__(
if all_kwargs.get("logprobs", False):
all_kwargs.setdefault("top_logprobs", 10)

# Check cache and separate cached and uncached messages
hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
cached_responses = [self.cache.get(hash) for hash in hashed_messages]
uncached_data = [
(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None
]
if lotus.settings.enable_cache:
# Check cache and separate cached and uncached messages
hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
cached_responses = [self.cache.get(hash) for hash in hashed_messages]

uncached_data = (
[(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None]
if lotus.settings.enable_cache
else [(msg, "no-cache") for msg in messages]
)

self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
uncached_responses = self._process_uncached_messages(
uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
)

# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
self._cache_response(resp, hash)
if lotus.settings.enable_cache:
# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
self._cache_response(resp, hash)

# Merge all responses in original order and extract outputs
all_responses = self._merge_responses(cached_responses, uncached_responses)
all_responses = (
self._merge_responses(cached_responses, uncached_responses)
if lotus.settings.enable_cache
else uncached_responses
)
outputs = [self._get_top_choice(resp) for resp in all_responses]
logprobs = (
[self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
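The message-level path in __call__ now checks lotus.settings.enable_cache before hashing, partitioning, and merging. A standalone sketch of that partition-and-merge pattern, using hypothetical names (call_with_cache, run_model, and echo_model are illustrations, not lotus APIs):

from typing import Any, Callable

def call_with_cache(messages: list[Any], cache: dict, run_model: Callable) -> list[Any]:
    hashes = [str(hash(repr(m))) for m in messages]    # stand-in for _hash_messages
    cached = [cache.get(h) for h in hashes]
    uncached = [(m, h) for m, h, r in zip(messages, hashes, cached) if r is None]

    fresh = run_model([m for m, _ in uncached])        # batch only the cache misses
    for (_, h), resp in zip(uncached, fresh):
        cache[h] = resp                                 # populate the cache for next time

    fresh_iter = iter(fresh)
    # Re-merge in the original order: cached responses stay put, misses are filled in.
    return [r if r is not None else next(fresh_iter) for r in cached]

def echo_model(batch: list[str]) -> list[str]:
    # Stand-in for the real LM call; returns a deterministic fake response.
    return [f"response to {m}" for m in batch]

cache: dict = {}
first = call_with_cache(["a", "b"], cache, echo_model)
second = call_with_cache(["a", "b"], cache, echo_model)  # served entirely from the cache
assert first == second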
3 changes: 1 addition & 2 deletions lotus/settings.py
@@ -12,8 +12,7 @@ class Settings:
reranker: lotus.models.Reranker | None = None

# Cache settings
enable_message_cache: bool = False
enable_operator_cache: bool = False
enable_cache: bool = False

# Serialization setting
serialization_format: SerializationFormat = SerializationFormat.DEFAULT