operator level cache #65

Merged · 22 commits · Dec 27, 2024
Changes from 8 commits

42 changes: 42 additions & 0 deletions lotus/cache.py
@@ -1,3 +1,5 @@
import hashlib
import json
import os
import pickle
import sqlite3
@@ -8,6 +10,8 @@
from functools import wraps
from typing import Any, Callable

import pandas as pd

import lotus


@@ -23,6 +27,44 @@ def wrapper(self, *args, **kwargs):
    return wrapper


def operator_cache(func: Callable) -> Callable:
    """Decorator to add operator level caching."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        model = lotus.settings.lm
        use_operator_cache = kwargs.get("use_operator_cache", False)

        if use_operator_cache and model.cache:

            def serialize(value):
                if isinstance(value, pd.DataFrame):
                    return value.to_json()
                elif hasattr(value, "dict"):
                    return value.dict()
                return value

            serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
            serialized_args = [serialize(arg) for arg in args]
            cache_key = hashlib.sha256(
                json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
            ).hexdigest()

            cached_result = model.cache.get(cache_key)
            if cached_result is not None:
                print(f"Cache hit for {cache_key}")
Review comment (Collaborator): Use lotus.logger rather than prints.
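A rough sketch of that suggestion (hypothetical, not part of this diff; it assumes lotus.logger behaves like a standard logging.Logger and that debug is an appropriate level):

```python
# Hypothetical rewrite of the cache hit/miss messages via lotus.logger instead of print.
# `model` and `cache_key` come from the enclosing wrapper; the log level is an assumption.
cached_result = model.cache.get(cache_key)
if cached_result is not None:
    lotus.logger.debug(f"Cache hit for {cache_key}")
    return cached_result
lotus.logger.debug(f"Cache miss for {cache_key}")
```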

                return cached_result
            print(f"Cache miss for {cache_key}")

            result = func(self, *args, **kwargs)
            model.cache.insert(cache_key, result)
            return result

        return func(self, *args, **kwargs)

    return wrapper


class CacheType(Enum):
    IN_MEMORY = "in_memory"
    SQLITE = "sqlite"
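For context, a minimal usage sketch of the operator-level cache (hypothetical, not part of this diff; the LM construction and cache wiring are assumptions):

```python
# Hypothetical usage sketch: with use_operator_cache=True and a cache attached to the
# configured LM (model.cache), a repeated identical operator call is served from the cache.
import pandas as pd

import lotus
from lotus.models import LM

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))  # assumes this LM exposes a cache

df = pd.DataFrame({"review": ["Great battery life", "Screen cracked on day one"]})

# First call: cache miss, so the result is computed and stored under a SHA-256 key built
# from the serialized args/kwargs. Second identical call: cache hit, cached result returned.
out1 = df.sem_map("Summarize {review} in three words", use_operator_cache=True)
out2 = df.sem_map("Summarize {review} in three words", use_operator_cache=True)
```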

5 changes: 5 additions & 0 deletions lotus/sem_ops/sem_agg.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus.models
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticAggOutput

@@ -14,6 +15,7 @@ def sem_agg(
partition_ids: list[int],
safe_mode: bool = False,
progress_bar_desc: str = "Aggregating",
use_operator_cache: bool = False,
Review comment (Collaborator): High-level question: should use_operator_cache be a parameter here, or is it better off in Settings? I think there is an argument for the latter, since I see this pattern being used everywhere.

Reply (Collaborator Author, @StanChan03, Dec 27, 2024): Putting it in Settings is better, I'd say; I can put it in Settings.
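A rough sketch of the Settings-based alternative discussed above (hypothetical; the setting name and configuration call are assumptions, not part of this diff):

```python
# Hypothetical Settings-based toggle: enable operator caching once, globally.
import lotus

lotus.settings.configure(use_operator_cache=True)  # assumed setting name

# The decorator would then read the flag from settings instead of threading a kwarg
# through every operator signature:
use_operator_cache = getattr(lotus.settings, "use_operator_cache", False)
```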

) -> SemanticAggOutput:
"""
Aggregates multiple documents into a single answer using a model.
@@ -148,6 +150,7 @@ def process_group(args):
group, user_instruction, all_cols, suffix, progress_bar_desc = args
return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)

@operator_cache
def __call__(
self,
user_instruction: str,
@@ -156,6 +159,7 @@ def __call__(
group_by: list[str] | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Aggregating",
use_operator_cache: bool = False,
) -> pd.DataFrame:
"""
Applies semantic aggregation over a dataframe.
Expand Down Expand Up @@ -213,6 +217,7 @@ def __call__(
partition_ids,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
use_operator_cache=use_operator_cache,
)

# package answer in a dataframe

5 changes: 4 additions & 1 deletion lotus/sem_ops/sem_cluster_by.py
@@ -4,6 +4,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache


@pd.api.extensions.register_dataframe_accessor("sem_cluster_by")
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
col_name: str,
@@ -27,6 +29,7 @@ def __call__(
return_centroids: bool = False,
niter: int = 20,
verbose: bool = False,
use_operator_cache: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, np.ndarray]:
"""
Perform semantic clustering on the DataFrame.
@@ -52,7 +55,7 @@ def __call__(
self._obj["cluster_id"] = pd.Series(indices, index=self._obj.index)
# if return_scores:
# self._obj["centroid_sim_score"] = pd.Series(scores, index=self._obj.index)

# if return_centroids:
# return self._obj, centroids
# else:

6 changes: 5 additions & 1 deletion lotus/sem_ops/sem_extract.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
@@ -19,6 +20,7 @@ def sem_extract(
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
use_operator_cache: bool = False,
) -> SemanticExtractOutput:
"""
Extracts attributes and values from a list of documents using a model.
@@ -33,7 +35,6 @@
Returns:
SemanticExtractOutput: The outputs, raw outputs, and quotes.
"""

# prepare model inputs
inputs = []
for doc in docs:
@@ -72,6 +73,7 @@ def _validate(obj: pd.DataFrame) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
input_cols: list[str],
@@ -81,6 +83,7 @@ def __call__(
return_raw_outputs: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
use_operator_cache: bool = False,
) -> pd.DataFrame:
"""
Extracts the attributes and values of a dataframe.
@@ -115,6 +118,7 @@ def __call__(
postprocessor=postprocessor,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
use_operator_cache=use_operator_cache,
)

new_df = self._obj.copy()

8 changes: 8 additions & 0 deletions lotus/sem_ops/sem_filter.py
@@ -5,6 +5,7 @@
from numpy.typing import NDArray

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
from lotus.utils import show_safe_mode
@@ -26,6 +27,7 @@ def sem_filter(
safe_mode: bool = False,
show_progress_bar: bool = True,
progress_bar_desc: str = "Filtering",
use_operator_cache: bool = False,
) -> SemanticFilterOutput:
"""
Filters a list of documents based on a given user instruction using a language model.
@@ -103,6 +105,7 @@ def learn_filter_cascade_thresholds(
strategy=strategy,
safe_mode=False,
progress_bar_desc="Running oracle for threshold learning",
use_operator_cache=False,
).outputs

best_combination, _ = learn_cascade_thresholds(
@@ -134,6 +137,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
user_instruction: str,
@@ -148,6 +152,7 @@ def __call__(
return_stats: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Filtering",
use_operator_cache: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
"""
Applies semantic filter over a dataframe.
@@ -245,6 +250,7 @@ def __call__(
safe_mode=safe_mode,
show_progress_bar=True,
progress_bar_desc="Running helper LM",
use_operator_cache=use_operator_cache,
)
helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
assert helper_logprobs is not None
@@ -325,6 +331,7 @@ def __call__(
strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc="Running predicate evals with oracle LM",
use_operator_cache=use_operator_cache,
)

for idx, large_idx in enumerate(low_conf_idxs):
@@ -348,6 +355,7 @@ def __call__(
safe_mode=safe_mode,
show_progress_bar=True,
progress_bar_desc=progress_bar_desc,
use_operator_cache=use_operator_cache,
)
outputs = output.outputs
raw_outputs = output.raw_outputs

10 changes: 10 additions & 0 deletions lotus/sem_ops/sem_join.py
@@ -4,6 +4,7 @@
from tqdm import tqdm

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, SemanticJoinOutput
from lotus.utils import show_safe_mode
@@ -29,6 +30,7 @@ def sem_join(
safe_mode: bool = False,
show_progress_bar: bool = True,
progress_bar_desc: str = "Join comparisons",
use_operator_cache: bool = False,
) -> SemanticJoinOutput:
"""
Joins two series using a model.
@@ -90,6 +92,7 @@ def sem_join(
default=default,
strategy=strategy,
show_progress_bar=False,
use_operator_cache=use_operator_cache,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
@@ -139,6 +142,7 @@ def sem_join_cascade(
default: bool = True,
strategy: str | None = None,
safe_mode: bool = False,
use_operator_cache: bool = False,
) -> SemanticJoinOutput:
"""
Joins two series using a cascade helper model and an oracle model.
@@ -235,6 +239,7 @@ def sem_join_cascade(
default=default,
strategy=strategy,
show_progress_bar=False,
use_operator_cache=use_operator_cache,
)
pbar.update(num_large)
pbar.close()
@@ -513,6 +518,7 @@ def learn_join_cascade_threshold(
cot_reasoning=cot_reasoning,
strategy=strategy,
progress_bar_desc="Running oracle for threshold learning",
use_operator_cache=False,
)

(pos_threshold, neg_threshold), _ = learn_cascade_thresholds(
@@ -545,6 +551,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
other: pd.DataFrame | pd.Series,
@@ -559,6 +566,7 @@ def __call__(
return_stats: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Join comparisons",
use_operator_cache: bool = False,
) -> pd.DataFrame:
"""
Applies semantic join over a dataframe.
@@ -672,6 +680,7 @@ def __call__(
default=default,
strategy=strategy,
safe_mode=safe_mode,
use_operator_cache=use_operator_cache,
)
else:
output = sem_join(
@@ -690,6 +699,7 @@ def __call__(
strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
use_operator_cache=use_operator_cache,
)
join_results = output.join_results
all_raw_outputs = output.all_raw_outputs

5 changes: 5 additions & 0 deletions lotus/sem_ops/sem_map.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput
from lotus.utils import show_safe_mode
@@ -21,6 +22,7 @@ def sem_map(
strategy: str | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Mapping",
use_operator_cache: bool = False,
) -> SemanticMapOutput:
"""
Maps a list of documents to a list of outputs using a model.
@@ -80,6 +82,7 @@ def _validate(obj: pd.DataFrame) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
user_instruction: str,
@@ -91,6 +94,7 @@ def __call__(
strategy: str | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Mapping",
use_operator_cache: bool = False,
) -> pd.DataFrame:
"""
Applies semantic map over a dataframe.
@@ -145,6 +149,7 @@ def __call__(
strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
use_operator_cache=use_operator_cache,
)

new_df = self._obj.copy()

3 changes: 3 additions & 0 deletions lotus/sem_ops/sem_search.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.types import RerankerOutput, RMOutput


@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
col_name: str,
@@ -27,6 +29,7 @@ def __call__(
n_rerank: int | None = None,
return_scores: bool = False,
suffix: str = "_sim_score",
use_operator_cache: bool = False,
) -> pd.DataFrame:
"""
Perform semantic search on the DataFrame.