Commit 8a207aa

Independent operator level cache + doc update (#74)
Independent operator level caching + doc updates

StanChan03 authored Jan 13, 2025
1 parent 9761855 commit 8a207aa
Showing 5 changed files with 67 additions and 33 deletions.
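At a glance, this commit collapses the separate `enable_message_cache` and `enable_operator_cache` flags into a single `enable_cache` setting that gates both the message-level and the operator-level cache. A minimal sketch of the new configuration, assuming the import paths below (the diff shows only the call sites, not the imports):

```python
import lotus
from lotus.cache import CacheConfig, CacheFactory, CacheType  # import paths assumed
from lotus.models import LM

# Build a SQLite-backed cache holding up to 1000 entries and attach it to the model.
cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
cache = CacheFactory.create_cache(cache_config)
lm = LM(model="gpt-4o-mini", cache=cache)

# One flag now gates both caching layers.
lotus.settings.configure(lm=lm, enable_cache=True)
```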
17 changes: 10 additions & 7 deletions .github/tests/lm_tests.py
@@ -399,7 +399,7 @@ def test_custom_tokenizer():
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_cache(setup_models, model):
lm = setup_models[model]
- lotus.settings.configure(lm=lm, enable_message_cache=True)
+ lotus.settings.configure(lm=lm, enable_cache=True)

# Check that "What is the capital of France?" becomes cached
first_batch = [
@@ -428,19 +428,20 @@ def test_cache(setup_models, model):
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_disable_cache(setup_models, model):
lm = setup_models[model]
- lotus.settings.configure(lm=lm, enable_message_cache=False)
+ lotus.settings.configure(lm=lm, enable_cache=False)

batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]

lm(batch)
assert lm.stats.total_usage.cache_hits == 0
lm(batch)
assert lm.stats.total_usage.cache_hits == 0

# Now enable cache. Note that the first batch is not cached.
- lotus.settings.configure(enable_message_cache=True)
+ lotus.settings.configure(enable_cache=True)
first_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 0
second_responses = lm(batch).outputs
@@ -451,7 +452,7 @@ def test_disable_cache(setup_models, model):
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_reset_cache(setup_models, model):
lm = setup_models[model]
- lotus.settings.configure(lm=lm, enable_message_cache=True)
+ lotus.settings.configure(lm=lm, enable_cache=True)

batch = [
[{"role": "user", "content": "Hello, world!"}],
@@ -481,7 +482,7 @@ def test_operator_cache(setup_models, model):
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
- lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)
+ lotus.settings.configure(lm=lm, enable_cache=True)

data = {
"Course Name": [
@@ -537,7 +538,7 @@ def test_disable_operator_cache(setup_models, model):
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
- lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False)
+ lotus.settings.configure(lm=lm, enable_cache=False)

data = {
"Course Name": [
@@ -569,15 +570,17 @@ def test_disable_operator_cache(setup_models, model):
user_instruction = "What is a similar course to {Course Name}. Please just output the course name."

first_response = df.sem_map(user_instruction)
first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0

second_response = df.sem_map(user_instruction)
second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0

pd.testing.assert_frame_equal(first_response, second_response)

# Now enable operator cache.
- lotus.settings.configure(enable_operator_cache=True)
+ lotus.settings.configure(enable_cache=True)
first_responses = df.sem_map(user_instruction)
first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0
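The tests above exercise the toggle semantics: with `enable_cache=False` repeated batches never register hits, and the flag can be flipped on mid-session. A hedged sketch of that flow (the final hit count of 2 assumes one hit per repeated message, per the bookkeeping in `lm.py` further down):

```python
import lotus

# `lm` is assumed to be an LM configured with a cache, as in the sketch at the top of this page.
batch = [
    [{"role": "user", "content": "Hello, world!"}],
    [{"role": "user", "content": "What is the capital of France?"}],
]

lotus.settings.configure(enable_cache=False)
lm(batch)
lm(batch)
assert lm.stats.total_usage.cache_hits == 0  # caching disabled: nothing is ever a hit

lotus.settings.configure(enable_cache=True)
lm(batch)  # first call after enabling: populates the cache, still no hits
lm(batch)  # identical batch: both messages should now be served from the cache
assert lm.stats.total_usage.cache_hits == 2
```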
11 changes: 10 additions & 1 deletion docs/configurations.rst
@@ -24,9 +24,18 @@ Configurable Parameters
1. enable_cache:
* Description: Enables or disables caching mechanisms
* Default: False
+ * Parameters:
+ - cache_type: Type of caching (SQLITE or IN_MEMORY)
+ - max_size: Maximum size of the cache
+ - cache_dir: Directory where the DB file is stored. Default: "~/.lotus/cache"
* Note: It is recommended to enable caching.
.. code-block:: python
- lotus.settings.configure(enable_message_cache=True)
+ cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
+ cache = CacheFactory.create_cache(cache_config)
+ lm = LM(model='gpt-4o-mini', cache=cache)
+ lotus.settings.configure(lm=lm, enable_cache=True)
2. setting RM:
* Description: Configures the retrieval model
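The parameters documented above also cover a non-persistent variant; a sketch assuming `CacheType.IN_MEMORY` is the enum member behind the `IN_MEMORY` option:

```python
from lotus.cache import CacheConfig, CacheFactory, CacheType  # import paths assumed

# In-memory cache capped at 500 entries; nothing is written to disk,
# so cache_dir is irrelevant and the cache is lost when the process exits.
in_memory_config = CacheConfig(cache_type=CacheType.IN_MEMORY, max_size=500)
cache = CacheFactory.create_cache(in_memory_config)
```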
38 changes: 26 additions & 12 deletions lotus/cache.py
@@ -20,7 +20,7 @@ def require_cache_enabled(func: Callable) -> Callable:

@wraps(func)
def wrapper(self, *args, **kwargs):
- if not lotus.settings.enable_message_cache:
+ if not lotus.settings.enable_cache:
return None
return func(self, *args, **kwargs)

@@ -33,21 +33,39 @@ def operator_cache(func: Callable) -> Callable:
@wraps(func)
def wrapper(self, *args, **kwargs):
model = lotus.settings.lm
- use_operator_cache = lotus.settings.enable_operator_cache
+ use_operator_cache = lotus.settings.enable_cache

if use_operator_cache and model.cache:

- def serialize(value):
- if isinstance(value, pd.DataFrame):
- return value.to_json()
+ def serialize(value: Any) -> Any:
+ """
+ Serialize a value into a JSON-serializable format.
+ Supports basic types, pandas DataFrames, and objects with a `dict` or `__dict__` method.
+ """
+ if value is None or isinstance(value, (str, int, float, bool)):
+ return value
+ elif isinstance(value, pd.DataFrame):
+ return value.to_json(orient="split")
+ elif isinstance(value, (list, tuple)):
+ return [serialize(item) for item in value]
elif isinstance(value, dict):
return {key: serialize(val) for key, val in value.items()}
elif hasattr(value, "dict"):
return value.dict()
- return value
+ elif hasattr(value, "__dict__"):
+ return {key: serialize(val) for key, val in vars(value).items() if not key.startswith("_")}
+ else:
+ # For unsupported types, convert to string (last resort)
+ lotus.logger.warning(f"Unsupported type {type(value)} for serialization. Converting to string.")
+ return str(value)

+ serialize_self = serialize(self)
serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
serialized_args = [serialize(arg) for arg in args]
cache_key = hashlib.sha256(
json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
json.dumps(
{"self": serialize_self, "args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True
).encode()
).hexdigest()

cached_result = model.cache.get(cache_key)
@@ -134,7 +152,6 @@ def _create_table(self):
def _get_time(self):
return int(time.time())

- @require_cache_enabled
def get(self, key: str) -> Any | None:
with self.conn:
cursor = self.conn.execute("SELECT value FROM cache WHERE key = ?", (key,))
@@ -152,7 +169,6 @@ def get(self, key: str) -> Any | None:
return value
return None

- @require_cache_enabled
def insert(self, key: str, value: Any):
pickled_value = pickle.dumps(value)
with self.conn:
@@ -196,14 +212,12 @@ def __init__(self, max_size: int):
super().__init__(max_size)
self.cache: OrderedDict[str, Any] = OrderedDict()

- @require_cache_enabled
def get(self, key: str) -> Any | None:
if key in self.cache:
lotus.logger.debug(f"Cache hit for {key}")

return self.cache.get(key)

- @require_cache_enabled
def insert(self, key: str, value: Any):
self.cache[key] = value

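For intuition, the operator cache key above is just a SHA-256 over a canonical JSON view of the operator's inputs, now including the serialized `self` (typically a DataFrame). A standalone sketch of the scheme, with a made-up DataFrame and instruction (not the library's exact code):

```python
import hashlib
import json

import pandas as pd

df = pd.DataFrame({"Course Name": ["Probability and Random Processes"]})

# Mirror serialize() for the types involved here: DataFrames become
# orient="split" JSON, so identical frames hash to identical keys.
payload = {
    "self": df.to_json(orient="split"),
    "args": ["What is a similar course to {Course Name}?"],
    "kwargs": {},
}

# sort_keys=True canonicalizes the JSON before hashing, so key order
# in kwargs cannot change the cache key.
cache_key = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
print(cache_key[:16])  # stable across runs for the same inputs
```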
31 changes: 20 additions & 11 deletions lotus/models/lm.py
@@ -55,25 +55,34 @@ def __call__(
if all_kwargs.get("logprobs", False):
all_kwargs.setdefault("top_logprobs", 10)

- # Check cache and separate cached and uncached messages
- hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
- cached_responses = [self.cache.get(hash) for hash in hashed_messages]
- uncached_data = [
- (msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None
- ]
+ if lotus.settings.enable_cache:
+ # Check cache and separate cached and uncached messages
+ hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
+ cached_responses = [self.cache.get(hash) for hash in hashed_messages]

+ uncached_data = (
+ [(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None]
+ if lotus.settings.enable_cache
+ else [(msg, "no-cache") for msg in messages]
+ )

self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
uncached_responses = self._process_uncached_messages(
uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
)

- # Add new responses to cache
- for resp, (_, hash) in zip(uncached_responses, uncached_data):
- self._cache_response(resp, hash)
+ if lotus.settings.enable_cache:
+ # Add new responses to cache
+ for resp, (_, hash) in zip(uncached_responses, uncached_data):
+ self._cache_response(resp, hash)

# Merge all responses in original order and extract outputs
- all_responses = self._merge_responses(cached_responses, uncached_responses)
+ all_responses = (
+ self._merge_responses(cached_responses, uncached_responses)
+ if lotus.settings.enable_cache
+ else uncached_responses
+ )
outputs = [self._get_top_choice(resp) for resp in all_responses]
logprobs = (
[self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
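The merge step restores the original message order after cached and uncached responses are split apart. A hypothetical standalone version of that logic (`_merge_responses` itself is not shown in this diff):

```python
from typing import Any

def merge_responses(cached: list[Any | None], uncached: list[Any]) -> list[Any]:
    """Fill each cache miss (None) with the next fresh response, preserving order."""
    fresh = iter(uncached)
    return [resp if resp is not None else next(fresh) for resp in cached]

# Messages 0 and 2 missed the cache; message 1 was a hit.
assert merge_responses([None, "hit", None], ["a", "b"]) == ["a", "hit", "b"]
```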
3 changes: 1 addition & 2 deletions lotus/settings.py
@@ -12,8 +12,7 @@ class Settings:
reranker: lotus.models.Reranker | None = None

# Cache settings
- enable_message_cache: bool = False
- enable_operator_cache: bool = False
+ enable_cache: bool = False

# Serialization setting
serialization_format: SerializationFormat = SerializationFormat.DEFAULT
