operator level cache #65

Merged 22 commits on Dec 27, 2024
Changes from all commits
126 changes: 122 additions & 4 deletions .github/tests/lm_tests.py
@@ -5,6 +5,7 @@
 from tokenizers import Tokenizer
 
 import lotus
+from lotus.cache import CacheConfig, CacheFactory, CacheType
 from lotus.models import LM, SentenceTransformersRM
 from lotus.types import CascadeArgs
 
@@ -398,7 +399,7 @@ def test_custom_tokenizer():
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_cache=True)
+    lotus.settings.configure(lm=lm, enable_message_cache=True)
 
     # Check that "What is the capital of France?" becomes cached
     first_batch = [
@@ -427,7 +428,7 @@ def test_cache(setup_models, model):
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_disable_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_cache=False)
+    lotus.settings.configure(lm=lm, enable_message_cache=False)
 
     batch = [
         [{"role": "user", "content": "Hello, world!"}],
@@ -439,7 +440,7 @@ def test_disable_cache(setup_models, model):
     assert lm.stats.total_usage.cache_hits == 0
 
     # Now enable cache. Note that the first batch is not cached.
-    lotus.settings.configure(enable_cache=True)
+    lotus.settings.configure(enable_message_cache=True)
     first_responses = lm(batch).outputs
     assert lm.stats.total_usage.cache_hits == 0
     second_responses = lm(batch).outputs
@@ -450,7 +451,7 @@ def test_reset_cache(setup_models, model):
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_reset_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_cache=True)
+    lotus.settings.configure(lm=lm, enable_message_cache=True)
 
     batch = [
         [{"role": "user", "content": "Hello, world!"}],
@@ -472,3 +473,120 @@ def test_reset_cache(setup_models, model):
     assert lm.stats.total_usage.cache_hits == 3
     lm(batch)
     assert lm.stats.total_usage.cache_hits == 3
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_operator_cache(setup_models, model):
+    cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
+    cache = CacheFactory.create_cache(cache_config)
+
+    lm = LM(model="gpt-4o-mini", cache=cache)
+    lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)
+
+    data = {
+        "Course Name": [
+            "Dynamics and Control of Chemical Processes",
+            "Optimization Methods in Engineering",
+            "Chemical Kinetics and Catalysis",
+            "Transport Phenomena and Separations",
+        ]
+    }
+
+    expected_response = pd.DataFrame(
+        {
+            "Course Name": [
+                "Dynamics and Control of Chemical Processes",
+                "Optimization Methods in Engineering",
+                "Chemical Kinetics and Catalysis",
+                "Transport Phenomena and Separations",
+            ],
+            "_map": [
+                "Process Dynamics and Control",
+                "Advanced Optimization Techniques in Engineering",
+                "Reaction Kinetics and Mechanisms",
+                "Fluid Mechanics and Mass Transfer",
+            ],
+        }
+    )
+
+    df = pd.DataFrame(data)
+    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."
+
+    first_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+    second_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 1
+
+    first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+
+    pd.testing.assert_frame_equal(first_response, second_response)
+    pd.testing.assert_frame_equal(first_response, expected_response)
+    pd.testing.assert_frame_equal(second_response, expected_response)
+
+    lm.reset_cache()
+    lm.reset_stats()
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_disable_operator_cache(setup_models, model):
+    cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
+    cache = CacheFactory.create_cache(cache_config)
+
+    lm = LM(model="gpt-4o-mini", cache=cache)
+    lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False)
+
+    data = {
+        "Course Name": [
+            "Dynamics and Control of Chemical Processes",
+            "Optimization Methods in Engineering",
+            "Chemical Kinetics and Catalysis",
+            "Transport Phenomena and Separations",
+        ]
+    }
+
+    expected_response = pd.DataFrame(
+        {
+            "Course Name": [
+                "Dynamics and Control of Chemical Processes",
+                "Optimization Methods in Engineering",
+                "Chemical Kinetics and Catalysis",
+                "Transport Phenomena and Separations",
+            ],
+            "_map": [
+                "Process Dynamics and Control",
+                "Advanced Optimization Techniques in Engineering",
+                "Reaction Kinetics and Mechanisms",
+                "Fluid Mechanics and Mass Transfer",
+            ],
+        }
+    )
+
+    df = pd.DataFrame(data)
+    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."
+
+    first_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+    second_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+    pd.testing.assert_frame_equal(first_response, second_response)
+
+    # Now enable operator cache.
+    lotus.settings.configure(enable_operator_cache=True)
+    first_responses = df.sem_map(user_instruction)
+    first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    assert lm.stats.total_usage.operator_cache_hits == 0
+    second_responses = df.sem_map(user_instruction)
+    second_responses["_map"] = second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    assert lm.stats.total_usage.operator_cache_hits == 1
+
+    expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+
+    pd.testing.assert_frame_equal(first_responses, second_responses)
+    pd.testing.assert_frame_equal(first_responses, expected_response)
+    pd.testing.assert_frame_equal(second_responses, expected_response)
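
Taken together, the new tests reduce to a simple round trip: build a SQLite-backed cache, hand it to the LM, enable both cache settings, and verify that repeating an operator call registers a hit. A minimal sketch of that flow, using only the API exercised in this diff (the one-row DataFrame is illustrative):

```python
import pandas as pd

import lotus
from lotus.cache import CacheConfig, CacheFactory, CacheType
from lotus.models import LM

# SQLite-backed cache holding up to 1000 entries, as in the tests above.
cache = CacheFactory.create_cache(CacheConfig(cache_type=CacheType.SQLITE, max_size=1000))
lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)

df = pd.DataFrame({"Course Name": ["Optimization Methods in Engineering"]})
instruction = "What is a similar course to {Course Name}. Please just output the course name."

df.sem_map(instruction)
assert lm.stats.total_usage.operator_cache_hits == 0  # first call: cache miss

df.sem_map(instruction)
assert lm.stats.total_usage.operator_cache_hits == 1  # identical call: served from cache
```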
4 changes: 2 additions & 2 deletions docs/configurations.rst
@@ -21,12 +21,12 @@ Using the Settings module
 Configurable Parameters
 --------------------------
 
-1. enable_cache:
+1. enable_message_cache:
     * Description: Enables or disables caching mechanisms
     * Default: False
     .. code-block:: python
 
-        lotus.settings.configure(enable_cache=True)
+        lotus.settings.configure(enable_message_cache=True)
 
 2. setting RM:
     * Description: Configures the retrieval model
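Note that this docs hunk covers only the renamed message-level flag; the new operator-level flag added elsewhere in this PR is configured the same way. A short sketch, assuming `enable_operator_cache` is exposed through `lotus.settings.configure` exactly as the tests use it:

```python
import lotus

# Message cache: memoizes raw LM calls. Operator cache: memoizes whole
# operator results (sem_map, sem_filter, ...). Both default to False.
lotus.settings.configure(enable_message_cache=True)
lotus.settings.configure(enable_operator_cache=True)
```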
2 changes: 1 addition & 1 deletion examples/model_examples/cache.py
@@ -11,7 +11,7 @@
 
 lm = LM(model="gpt-4o-mini", cache=cache)
 
-lotus.settings.configure(lm=lm, enable_cache=True)  # default caching is False
+lotus.settings.configure(lm=lm, enable_message_cache=True)  # default caching is False
 data = {
     "Course Name": [
         "Probability and Random Processes",
45 changes: 44 additions & 1 deletion lotus/cache.py
@@ -1,3 +1,5 @@
+import hashlib
+import json
 import os
 import pickle
 import sqlite3
@@ -8,6 +10,8 @@
 from functools import wraps
 from typing import Any, Callable
 
+import pandas as pd
+
 import lotus
 
 
@@ -16,13 +20,52 @@ def require_cache_enabled(func: Callable) -> Callable:
 
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if not lotus.settings.enable_cache:
+        if not lotus.settings.enable_message_cache:
             return None
         return func(self, *args, **kwargs)
 
     return wrapper
 
 
+def operator_cache(func: Callable) -> Callable:
+    """Decorator to add operator level caching."""
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        model = lotus.settings.lm
+        use_operator_cache = lotus.settings.enable_operator_cache
+
+        if use_operator_cache and model.cache:
+
+            def serialize(value):
+                if isinstance(value, pd.DataFrame):
+                    return value.to_json()
+                elif hasattr(value, "dict"):
+                    return value.dict()
+                return value
+
+            serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
+            serialized_args = [serialize(arg) for arg in args]
+            cache_key = hashlib.sha256(
+                json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
+            ).hexdigest()
+
+            cached_result = model.cache.get(cache_key)
+            if cached_result is not None:
+                lotus.logger.debug(f"Cache hit for {cache_key}")
+                model.stats.total_usage.operator_cache_hits += 1
+                return cached_result
+            lotus.logger.debug(f"Cache miss for {cache_key}")
+
+            result = func(self, *args, **kwargs)
+            model.cache.insert(cache_key, result)
+            return result
+
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
 class CacheType(Enum):
     IN_MEMORY = "in_memory"
     SQLITE = "sqlite"
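The heart of the decorator is its cache key: positional and keyword arguments are serialized (DataFrames via `to_json()`, objects exposing a `dict()` method via that method) and the sorted JSON payload is hashed with SHA-256, so the key is independent of keyword ordering. A self-contained sketch of that keying logic; `make_operator_cache_key` is an illustrative helper, not part of the library:

```python
import hashlib
import json

import pandas as pd


def make_operator_cache_key(*args, **kwargs) -> str:
    # Mirrors the serialize() helper inside operator_cache above.
    def serialize(value):
        if isinstance(value, pd.DataFrame):
            return value.to_json()
        elif hasattr(value, "dict"):
            return value.dict()
        return value

    payload = {
        "args": [serialize(arg) for arg in args],
        "kwargs": {key: serialize(value) for key, value in kwargs.items()},
    }
    # sort_keys makes the hash stable under kwarg reordering.
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()


df = pd.DataFrame({"Course Name": ["Chemical Kinetics and Catalysis"]})
key_a = make_operator_cache_key(df, user_instruction="similar course to {Course Name}")
key_b = make_operator_cache_key(df, user_instruction="similar course to {Course Name}")
assert key_a == key_b  # identical inputs hash to the same key
```

One consequence of keying on serialized inputs: any change to the DataFrame contents or the instruction string produces a new key, so stale results are never returned for modified inputs.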
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_agg.py
@@ -3,6 +3,7 @@
 import pandas as pd
 
 import lotus.models
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticAggOutput
 
@@ -148,6 +149,7 @@ def process_group(args):
         group, user_instruction, all_cols, suffix, progress_bar_desc = args
         return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)
 
+    @operator_cache
     def __call__(
         self,
         user_instruction: str,
4 changes: 3 additions & 1 deletion lotus/sem_ops/sem_cluster_by.py
@@ -4,6 +4,7 @@
 import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 
 
 @pd.api.extensions.register_dataframe_accessor("sem_cluster_by")
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         col_name: str,
@@ -52,7 +54,7 @@ def __call__(
         self._obj["cluster_id"] = pd.Series(indices, index=self._obj.index)
         # if return_scores:
         #     self._obj["centroid_sim_score"] = pd.Series(scores, index=self._obj.index)
 
         # if return_centroids:
         #     return self._obj, centroids
         # else:
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_extract.py
@@ -3,6 +3,7 @@
 import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.models import LM
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
@@ -33,7 +34,6 @@ def sem_extract(
     Returns:
         SemanticExtractOutput: The outputs, raw outputs, and quotes.
     """
-
     # prepare model inputs
     inputs = []
     for doc in docs:
@@ -72,6 +72,7 @@ def _validate(obj: pd.DataFrame) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         input_cols: list[str],
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_filter.py
@@ -5,6 +5,7 @@
 from numpy.typing import NDArray
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
 from lotus.utils import show_safe_mode
@@ -134,6 +135,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         user_instruction: str,
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_join.py
@@ -4,6 +4,7 @@
 from tqdm import tqdm
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import CascadeArgs, SemanticJoinOutput
 from lotus.utils import show_safe_mode
@@ -234,7 +235,6 @@ def sem_join_cascade(
         cot_reasoning=cot_reasoning,
         default=default,
         strategy=strategy,
-        show_progress_bar=False,
     )
     pbar.update(num_large)
     pbar.close()
@@ -545,6 +545,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         other: pd.DataFrame | pd.Series,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_map.py
@@ -3,6 +3,7 @@
 import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput
 from lotus.utils import show_safe_mode
@@ -80,6 +81,7 @@ def _validate(obj: pd.DataFrame) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         user_instruction: str,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_search.py
@@ -3,6 +3,7 @@
 import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.types import RerankerOutput, RMOutput
 
 
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         col_name: str,
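
Every sem_ops change in this PR follows the same two-line pattern: import `operator_cache` and decorate the accessor's `__call__`. A hedged sketch of the pattern on a hypothetical accessor (`sem_echo` and its body are invented for illustration; the decorator also requires `lotus.settings.lm` to be configured with a cache before it caches anything):

```python
from typing import Any

import pandas as pd

from lotus.cache import operator_cache


@pd.api.extensions.register_dataframe_accessor("sem_echo")
class SemEchoDataframe:
    """Hypothetical operator showing where @operator_cache attaches."""

    def __init__(self, pandas_obj: Any):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj: Any) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @operator_cache
    def __call__(self, col_name: str) -> pd.DataFrame:
        # Any DataFrame-in/DataFrame-out body works; the decorator keys on the
        # serialized arguments and returns the cached frame on a repeat call.
        return self._obj[[col_name]].copy()
```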