From 9bd4b0954f9d9bd74066b3b84385ea0b86129043 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Tue, 24 Dec 2024 21:58:55 -0800 Subject: [PATCH 01/21] sem_extract+sem_map level cache --- lotus/sem_ops/sem_extract.py | 29 +++++++++++++++++++++++++++- lotus/sem_ops/sem_map.py | 37 +++++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 053dc5a..350f803 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -1,3 +1,5 @@ +import hashlib +import json from typing import Any, Callable import pandas as pd @@ -19,6 +21,7 @@ def sem_extract( postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess, safe_mode: bool = False, progress_bar_desc: str = "Extracting", + use_operator_cache: bool = False, ) -> SemanticExtractOutput: """ Extracts attributes and values from a list of documents using a model. @@ -34,6 +37,23 @@ def sem_extract( SemanticExtractOutput: The outputs, raw outputs, and quotes. """ + # prepare docs for serialization + cache_docs = [json.loads(json.dumps(doc, sort_keys=True)) for doc in docs] + output_cols = {key: value for key, value in sorted(output_cols.items())} + + # generate cache key + cache_key = hashlib.sha256( + json.dumps({"docs": cache_docs, "output_cols": output_cols, "extract_quotes": extract_quotes}).encode() + ).hexdigest() + + # check cache + if use_operator_cache and model.cache: + cached_result = model.cache.get(cache_key) + if cached_result is not None: + print(f"Cache hit for {cache_key}") + return cached_result + print(f"Cache miss for {cache_key}") + # prepare model inputs inputs = [] for doc in docs: @@ -58,7 +78,12 @@ def sem_extract( if safe_mode: model.print_total_usage() - return SemanticExtractOutput(**postprocess_output.model_dump()) + result = SemanticExtractOutput(**postprocess_output.model_dump()) + + if use_operator_cache and model.cache: + print(f"Inserting cache for {cache_key}") + model.cache.insert(cache_key, result) + return result @pd.api.extensions.register_dataframe_accessor("sem_extract") @@ -81,6 +106,7 @@ def __call__( return_raw_outputs: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Extracting", + use_operator_cache: bool = False, ) -> pd.DataFrame: """ Extracts the attributes and values of a dataframe. @@ -115,6 +141,7 @@ def __call__( postprocessor=postprocessor, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, + use_operator_cache=use_operator_cache, ) new_df = self._obj.copy() diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 9e8b991..f738806 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -1,3 +1,5 @@ +import hashlib +import json from typing import Any, Callable import pandas as pd @@ -21,6 +23,7 @@ def sem_map( strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", + use_operator_cache: bool = False, ) -> SemanticMapOutput: """ Maps a list of documents to a list of outputs using a model. @@ -37,6 +40,31 @@ def sem_map( Returns: SemanticMapOutput: The outputs, raw outputs, and explanations. 
""" + # prepare docs for serialization + cache_docs = [json.loads(json.dumps(docs, sort_keys=True)) for doc in docs] + + # generate cache key + cache_key = hashlib.sha256( + json.dumps( + { + "docs": cache_docs, + "user_instruction": user_instruction, + "examples_multimodal_data": examples_multimodal_data, + "examples_answers": examples_answers, + "cot_reasoning": cot_reasoning, + "strategy": strategy, + } + ).encode() + ).hexdigest() + + # check cache + if use_operator_cache and model.cache: + cache_result = model.cache.get(cache_key) + if cache_result is not None: + print(f'Cache hit for "{cache_key}"') + return cache_result + print(f'Cache miss for "{cache_key}"') + # prepare model inputs inputs = [] for doc in docs: @@ -64,7 +92,12 @@ def sem_map( if safe_mode: model.print_total_usage() - return SemanticMapOutput(**postprocess_output.model_dump()) + result = SemanticMapOutput(**postprocess_output.model_dump()) + + if use_operator_cache and model.cache: + model.cache.insert(cache_key, result) + + return result @pd.api.extensions.register_dataframe_accessor("sem_map") @@ -91,6 +124,7 @@ def __call__( strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", + use_operator_cache: bool = False, ) -> pd.DataFrame: """ Applies semantic map over a dataframe. @@ -145,6 +179,7 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, + use_operator_cache=use_operator_cache, ) new_df = self._obj.copy() From 2057704555f447b7dbc71b58bcc507a04d04ea51 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Wed, 25 Dec 2024 15:17:04 -0800 Subject: [PATCH 02/21] general operator cache --- lotus/cache.py | 30 ++++++++++++++++++++++++++++++ lotus/sem_ops/sem_extract.py | 29 +++-------------------------- lotus/sem_ops/sem_map.py | 36 +++--------------------------------- 3 files changed, 36 insertions(+), 59 deletions(-) diff --git a/lotus/cache.py b/lotus/cache.py index 74cadd5..5d0d195 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -1,3 +1,5 @@ +import hashlib +import json import os import pickle import sqlite3 @@ -23,6 +25,34 @@ def wrapper(self, *args, **kwargs): return wrapper +def operator_cache(func: Callable) -> Callable: + """Decorator to add operator level caching.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + model = lotus.settings.lm + use_operator_cache = kwargs.get("use_operator_cache", False) + + if use_operator_cache and model.cache: + cache_key = hashlib.sha256( + json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True).encode() + ).hexdigest() + + cached_result = model.cache.get(cache_key) + if cached_result is not None: + print(f"Cache hit for {cache_key}") + return cached_result + print(f"Cache miss for {cache_key}") + + result = func(self, *args, **kwargs) + model.cache.insert(cache_key, result) + return result + + return func(self, *args, **kwargs) + + return wrapper + + class CacheType(Enum): IN_MEMORY = "in_memory" SQLITE = "sqlite" diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 350f803..42a6918 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -1,10 +1,9 @@ -import hashlib -import json from typing import Any, Callable import pandas as pd import lotus +from lotus.cache import operator_cache from lotus.models import LM from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput @@ -36,24 +35,6 @@ def sem_extract( Returns: SemanticExtractOutput: The outputs, raw 
outputs, and quotes. """ - - # prepare docs for serialization - cache_docs = [json.loads(json.dumps(doc, sort_keys=True)) for doc in docs] - output_cols = {key: value for key, value in sorted(output_cols.items())} - - # generate cache key - cache_key = hashlib.sha256( - json.dumps({"docs": cache_docs, "output_cols": output_cols, "extract_quotes": extract_quotes}).encode() - ).hexdigest() - - # check cache - if use_operator_cache and model.cache: - cached_result = model.cache.get(cache_key) - if cached_result is not None: - print(f"Cache hit for {cache_key}") - return cached_result - print(f"Cache miss for {cache_key}") - # prepare model inputs inputs = [] for doc in docs: @@ -78,12 +59,7 @@ def sem_extract( if safe_mode: model.print_total_usage() - result = SemanticExtractOutput(**postprocess_output.model_dump()) - - if use_operator_cache and model.cache: - print(f"Inserting cache for {cache_key}") - model.cache.insert(cache_key, result) - return result + return SemanticExtractOutput(**postprocess_output.model_dump()) @pd.api.extensions.register_dataframe_accessor("sem_extract") @@ -97,6 +73,7 @@ def _validate(obj: pd.DataFrame) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, input_cols: list[str], diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index f738806..46dc84a 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -1,10 +1,9 @@ -import hashlib -import json from typing import Any, Callable import pandas as pd import lotus +from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput from lotus.utils import show_safe_mode @@ -40,31 +39,6 @@ def sem_map( Returns: SemanticMapOutput: The outputs, raw outputs, and explanations. 
""" - # prepare docs for serialization - cache_docs = [json.loads(json.dumps(docs, sort_keys=True)) for doc in docs] - - # generate cache key - cache_key = hashlib.sha256( - json.dumps( - { - "docs": cache_docs, - "user_instruction": user_instruction, - "examples_multimodal_data": examples_multimodal_data, - "examples_answers": examples_answers, - "cot_reasoning": cot_reasoning, - "strategy": strategy, - } - ).encode() - ).hexdigest() - - # check cache - if use_operator_cache and model.cache: - cache_result = model.cache.get(cache_key) - if cache_result is not None: - print(f'Cache hit for "{cache_key}"') - return cache_result - print(f'Cache miss for "{cache_key}"') - # prepare model inputs inputs = [] for doc in docs: @@ -92,12 +66,7 @@ def sem_map( if safe_mode: model.print_total_usage() - result = SemanticMapOutput(**postprocess_output.model_dump()) - - if use_operator_cache and model.cache: - model.cache.insert(cache_key, result) - - return result + return SemanticMapOutput(**postprocess_output.model_dump()) @pd.api.extensions.register_dataframe_accessor("sem_map") @@ -113,6 +82,7 @@ def _validate(obj: pd.DataFrame) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, user_instruction: str, From 03947bd55621279e5ccb814cb039f2678c4940ef Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Wed, 25 Dec 2024 15:24:40 -0800 Subject: [PATCH 03/21] sem_agg operator cache --- lotus/sem_ops/sem_agg.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index dfb934b..cea5a65 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -3,6 +3,7 @@ import pandas as pd import lotus.models +from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticAggOutput @@ -14,6 +15,7 @@ def sem_agg( partition_ids: list[int], safe_mode: bool = False, progress_bar_desc: str = "Aggregating", + use_operator_cache: bool = False, ) -> SemanticAggOutput: """ Aggregates multiple documents into a single answer using a model. @@ -143,6 +145,7 @@ def __init__(self, pandas_obj: Any): def _validate(obj: Any) -> None: pass + @operator_cache def __call__( self, user_instruction: str, @@ -151,6 +154,7 @@ def __call__( group_by: list[str] | None = None, safe_mode: bool = False, progress_bar_desc: str = "Aggregating", + use_operator_cache: bool = False, ) -> pd.DataFrame: """ Applies semantic aggregation over a dataframe. @@ -181,9 +185,6 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"column {column} not found in DataFrame. 
Given usr instruction: {user_instruction}") - - - if group_by: grouped = self._obj.groupby(group_by) new_df = pd.DataFrame() @@ -191,9 +192,7 @@ def __call__( res = group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) new_df = pd.concat([new_df, res]) return new_df - - - + # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: self._obj = self._obj.sort_values(by="_lotus_partition_id") @@ -213,6 +212,7 @@ def __call__( partition_ids, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, + use_operator_cache=use_operator_cache, ) # package answer in a dataframe From 24f21b3db84b40c6a3837bec03e3d4dbd12a65d8 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Wed, 25 Dec 2024 20:30:05 -0800 Subject: [PATCH 04/21] sem_filter operator cache + kwarg serilization --- lotus/cache.py | 13 ++++++++++++- lotus/sem_ops/sem_filter.py | 8 ++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/lotus/cache.py b/lotus/cache.py index 5d0d195..990b57c 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -10,6 +10,8 @@ from functools import wraps from typing import Any, Callable +import pandas as pd + import lotus @@ -34,8 +36,17 @@ def wrapper(self, *args, **kwargs): use_operator_cache = kwargs.get("use_operator_cache", False) if use_operator_cache and model.cache: + + def serialize(value): + if isinstance(value, pd.DataFrame): + return value.to_json() + elif hasattr(value, "dict"): + return value.dict() + return value + + serilized_kwargs = {key: serialize(value) for key, value in kwargs.items()} cache_key = hashlib.sha256( - json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True).encode() + json.dumps({"args": args, "kwargs": serilized_kwargs}, sort_keys=True).encode() ).hexdigest() cached_result = model.cache.get(cache_key) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index c03a064..b6b7b19 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -5,6 +5,7 @@ from numpy.typing import NDArray import lotus +from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, SemanticFilterOutput from lotus.utils import show_safe_mode @@ -26,6 +27,7 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", + use_operator_cache: bool = False, ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model. @@ -103,6 +105,7 @@ def learn_filter_cascade_thresholds( strategy=strategy, safe_mode=False, progress_bar_desc="Running oracle for threshold learning", + use_operator_cache=False, ).outputs best_combination, _ = learn_cascade_thresholds( @@ -134,6 +137,7 @@ def _validate(obj: Any) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, user_instruction: str, @@ -148,6 +152,7 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", + use_operator_cache: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Applies semantic filter over a dataframe. 
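[The hashing scheme this patch introduces in lotus/cache.py can be sketched in isolation. The helper below is a minimal, self-contained reproduction for illustration only: make_operator_cache_key is not a library function, and serialization of positional args only lands in PATCH 06.]

    # A runnable sketch of the key derivation used by the operator_cache decorator.
    # make_operator_cache_key is an illustrative name, not part of lotus itself.
    import hashlib
    import json

    import pandas as pd


    def make_operator_cache_key(args: tuple, kwargs: dict) -> str:
        def serialize(value):
            # DataFrames are not JSON-serializable, so their JSON form is hashed instead.
            if isinstance(value, pd.DataFrame):
                return value.to_json()
            elif hasattr(value, "dict"):
                return value.dict()
            return value

        payload = {
            "args": [serialize(arg) for arg in args],
            "kwargs": {key: serialize(value) for key, value in kwargs.items()},
        }
        # sort_keys keeps the digest stable across dict insertion order.
        return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()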
@@ -245,6 +250,7 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc="Running helper LM", + use_operator_cache=use_operator_cache, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs assert helper_logprobs is not None @@ -325,6 +331,7 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", + use_operator_cache=use_operator_cache, ) for idx, large_idx in enumerate(low_conf_idxs): @@ -348,6 +355,7 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, + use_operator_cache=use_operator_cache, ) outputs = output.outputs raw_outputs = output.raw_outputs From 7ecd48db5a4811e0302a5ea72f8c862060617a99 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Wed, 25 Dec 2024 20:32:00 -0800 Subject: [PATCH 05/21] typo --- lotus/cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lotus/cache.py b/lotus/cache.py index 990b57c..92a50df 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -44,9 +44,9 @@ def serialize(value): return value.dict() return value - serilized_kwargs = {key: serialize(value) for key, value in kwargs.items()} + serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()} cache_key = hashlib.sha256( - json.dumps({"args": args, "kwargs": serilized_kwargs}, sort_keys=True).encode() + json.dumps({"args": args, "kwargs": serialized_kwargs}, sort_keys=True).encode() ).hexdigest() cached_result = model.cache.get(cache_key) From 813d83c6f13c91beeb9e95bf8a1aa38094581993 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 15:15:41 -0800 Subject: [PATCH 06/21] sem_join + sem_cluster operator level cache --- lotus/cache.py | 3 ++- lotus/sem_ops/sem_cluster_by.py | 5 ++++- lotus/sem_ops/sem_join.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/lotus/cache.py b/lotus/cache.py index 92a50df..94bd1ba 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -45,8 +45,9 @@ def serialize(value): return value serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()} + serialized_args = [serialize(arg) for arg in args] cache_key = hashlib.sha256( - json.dumps({"args": args, "kwargs": serialized_kwargs}, sort_keys=True).encode() + json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode() ).hexdigest() cached_result = model.cache.get(cache_key) diff --git a/lotus/sem_ops/sem_cluster_by.py b/lotus/sem_ops/sem_cluster_by.py index 5811101..4764db1 100644 --- a/lotus/sem_ops/sem_cluster_by.py +++ b/lotus/sem_ops/sem_cluster_by.py @@ -4,6 +4,7 @@ import pandas as pd import lotus +from lotus.cache import operator_cache @pd.api.extensions.register_dataframe_accessor("sem_cluster_by") @@ -19,6 +20,7 @@ def _validate(obj: Any) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, col_name: str, @@ -27,6 +29,7 @@ def __call__( return_centroids: bool = False, niter: int = 20, verbose: bool = False, + use_operator_cache: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, np.ndarray]: """ Perform semantic clustering on the DataFrame. 
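[With positional args now serialized as well, the effect on DataFrame-valued arguments can be checked directly. This demo assumes the make_operator_cache_key sketch shown earlier: equal contents hash to the same key, different contents to different keys.]

    df_a = pd.DataFrame({"Course Name": ["Chemical Kinetics and Catalysis"]})
    df_b = pd.DataFrame({"Course Name": ["Optimization Methods in Engineering"]})

    key_a = make_operator_cache_key((df_a,), {"user_instruction": "..."})
    key_b = make_operator_cache_key((df_b,), {"user_instruction": "..."})

    assert key_a != key_b  # different inputs get different cache entries
    assert key_a == make_operator_cache_key((df_a.copy(),), {"user_instruction": "..."})  # same content, same key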
@@ -52,7 +55,7 @@ def __call__( self._obj["cluster_id"] = pd.Series(indices, index=self._obj.index) # if return_scores: # self._obj["centroid_sim_score"] = pd.Series(scores, index=self._obj.index) - + # if return_centroids: # return self._obj, centroids # else: diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index abb2765..82f47f7 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -4,6 +4,7 @@ from tqdm import tqdm import lotus +from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import CascadeArgs, SemanticJoinOutput from lotus.utils import show_safe_mode @@ -29,6 +30,7 @@ def sem_join( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Join comparisons", + use_operator_cache: bool = False, ) -> SemanticJoinOutput: """ Joins two series using a model. @@ -90,6 +92,7 @@ def sem_join( default=default, strategy=strategy, show_progress_bar=False, + use_operator_cache=use_operator_cache, ) outputs = output.outputs raw_outputs = output.raw_outputs @@ -139,6 +142,7 @@ def sem_join_cascade( default: bool = True, strategy: str | None = None, safe_mode: bool = False, + use_operator_cache: bool = False, ) -> SemanticJoinOutput: """ Joins two series using a cascade helper model and a oracle model. @@ -235,6 +239,7 @@ def sem_join_cascade( default=default, strategy=strategy, show_progress_bar=False, + use_operator_cache=use_operator_cache, ) pbar.update(num_large) pbar.close() @@ -513,6 +518,7 @@ def learn_join_cascade_threshold( cot_reasoning=cot_reasoning, strategy=strategy, progress_bar_desc="Running oracle for threshold learning", + use_operator_cache=False, ) (pos_threshold, neg_threshold), _ = learn_cascade_thresholds( @@ -545,6 +551,7 @@ def _validate(obj: Any) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, other: pd.DataFrame | pd.Series, @@ -559,6 +566,7 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Join comparisons", + use_operator_cache: bool = False, ) -> pd.DataFrame: """ Applies semantic join over a dataframe. 
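[These patches always attach @operator_cache to the pandas accessor's __call__, so one cache entry covers a whole operator invocation rather than individual LM requests. A toy accessor, illustrative only ("toy_op" and its body are stand-ins), shows the attachment point:]

    import pandas as pd

    from lotus.cache import operator_cache


    @pd.api.extensions.register_dataframe_accessor("toy_op")
    class ToyOpAccessor:
        def __init__(self, pandas_obj: pd.DataFrame):
            self._obj = pandas_obj

        @operator_cache
        def __call__(self, user_instruction: str, use_operator_cache: bool = False) -> pd.DataFrame:
            # Stand-in for the real LM-backed body; the decorator sees this call's
            # arguments (including the use_operator_cache kwarg) when building the key.
            return self._obj.copy()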
@@ -672,6 +680,7 @@ def __call__( default=default, strategy=strategy, safe_mode=safe_mode, + use_operator_cache=use_operator_cache, ) else: output = sem_join( @@ -690,6 +699,7 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, + use_operator_cache=use_operator_cache, ) join_results = output.join_results all_raw_outputs = output.all_raw_outputs From 335df10758cb6c7d523466231d7a9eeffe559fea Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 16:40:17 -0800 Subject: [PATCH 07/21] sem_search + sim_sem_join + sem_topk --- lotus/sem_ops/sem_search.py | 3 +++ lotus/sem_ops/sem_sim_join.py | 3 +++ lotus/sem_ops/sem_topk.py | 5 ++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index 5846cc1..fd20de9 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -3,6 +3,7 @@ import pandas as pd import lotus +from lotus.cache import operator_cache from lotus.types import RerankerOutput, RMOutput @@ -19,6 +20,7 @@ def _validate(obj: Any) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, col_name: str, @@ -27,6 +29,7 @@ def __call__( n_rerank: int | None = None, return_scores: bool = False, suffix: str = "_sim_score", + use_operator_cache: bool = False, ) -> pd.DataFrame: """ Perform semantic search on the DataFrame. diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index b1fd986..7d30d41 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -3,6 +3,7 @@ import pandas as pd import lotus +from lotus.cache import operator_cache from lotus.models import RM from lotus.types import RMOutput @@ -20,6 +21,7 @@ def _validate(obj: Any) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, other: pd.DataFrame, @@ -30,6 +32,7 @@ def __call__( rsuffix: str = "", score_suffix: str = "", keep_index: bool = False, + use_operator_cache: bool = False, ) -> pd.DataFrame: """ Perform semantic similarity join on the DataFrame. diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index c92e81d..18101c5 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -7,6 +7,7 @@ from tqdm import tqdm import lotus +from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticTopKOutput from lotus.utils import show_safe_mode @@ -386,6 +387,7 @@ def process_group(args): return_stats=return_stats, ) + @operator_cache def __call__( self, user_instruction: str, @@ -396,6 +398,7 @@ def __call__( cascade_threshold: float | None = None, return_stats: bool = False, safe_mode: bool = False, + use_operator_cache: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Sorts the DataFrame based on the user instruction and returns the top K rows. 
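[As of this patch, caching is an explicit per-call opt-in via the kwarg. A call-site sketch, assuming a DataFrame df with a "Course Name" column and a configured LM; K and the instruction text are illustrative:]

    # Per-call opt-in as of PATCH 07; PATCH 08 replaces this kwarg with a
    # global settings switch.
    ranked = df.sem_topk("Which {Course Name} is most mathematical?", K=2, use_operator_cache=True)
    again = df.sem_topk("Which {Course Name} is most mathematical?", K=2, use_operator_cache=True)  # cache hit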
@@ -438,7 +441,7 @@ def __call__(
         with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor:
             results = list(executor.map(SemTopKDataframe.process_group, group_args))
-
+
         if return_stats:
             new_df = pd.concat([res[0] for res in results])
             stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)}

From ebced3c9d948ca6645a73ca54e864a57f3966072 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 17:08:29 -0800
Subject: [PATCH 08/21] enable_operator_cache in settings
---
 .../op_examples/~/.lotus/cache/lotus_cache.db | Bin 0 -> 45056 bytes
 lotus/cache.py                                |   2 +-
 lotus/sem_ops/sem_agg.py                      |   3 ---
 lotus/sem_ops/sem_cluster_by.py               |   1 -
 lotus/sem_ops/sem_extract.py                  |   3 ---
 lotus/sem_ops/sem_filter.py                   |   6 ------
 lotus/sem_ops/sem_join.py                     |   9 ---------
 lotus/sem_ops/sem_map.py                      |   3 ---
 lotus/sem_ops/sem_search.py                   |   1 -
 lotus/sem_ops/sem_sim_join.py                 |   1 -
 lotus/sem_ops/sem_topk.py                     |   1 -
 lotus/settings.py                             |   1 +
 12 files changed, 2 insertions(+), 29 deletions(-)
 create mode 100644 examples/op_examples/~/.lotus/cache/lotus_cache.db

diff --git a/examples/op_examples/~/.lotus/cache/lotus_cache.db b/examples/op_examples/~/.lotus/cache/lotus_cache.db
new file mode 100644
index 0000000000000000000000000000000000000000..cc70c38696d3bd753def7749af8dfe98717624a6
GIT binary patch
literal 45056
[... base85-encoded binary payload omitted: an accidentally committed SQLite cache database ...]

diff --git a/lotus/cache.py b/lotus/cache.py
index 94bd1ba..a5c90fb 100644
--- a/lotus/cache.py
+++ b/lotus/cache.py
@@ -31,7 +31,7 @@ def operator_cache(func: Callable) -> Callable:
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         model = lotus.settings.lm
-        use_operator_cache = kwargs.get("use_operator_cache", False)
+        use_operator_cache = lotus.settings.enable_operator_cache

diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py
index fb31763..706f12f 100644
--- a/lotus/sem_ops/sem_agg.py
+++ b/lotus/sem_ops/sem_agg.py
@@ -15,7 +15,6 @@ def sem_agg(
     partition_ids: list[int],
     safe_mode: bool = False,
     progress_bar_desc: str = "Aggregating",
-    use_operator_cache: bool = False,
 ) -> SemanticAggOutput:
     """
     Aggregates multiple documents into a single answer using a model.
@@ -159,7 +158,6 @@ def __call__(
         group_by: list[str] | None = None,
         safe_mode: bool = False,
         progress_bar_desc: str = "Aggregating",
-        use_operator_cache: bool = False,
     ) -> pd.DataFrame:
         """
         Applies semantic aggregation over a dataframe.
@@ -217,7 +215,6 @@ def __call__(
             partition_ids,
             safe_mode=safe_mode,
             progress_bar_desc=progress_bar_desc,
-            use_operator_cache=use_operator_cache,
         )

         # package answer in a dataframe
diff --git a/lotus/sem_ops/sem_cluster_by.py b/lotus/sem_ops/sem_cluster_by.py
index 4764db1..fc8a9c6 100644
--- a/lotus/sem_ops/sem_cluster_by.py
+++ b/lotus/sem_ops/sem_cluster_by.py
@@ -29,7 +29,6 @@ def __call__(
         return_centroids: bool = False,
         niter: int = 20,
         verbose: bool = False,
-        use_operator_cache: bool = False,
     ) -> pd.DataFrame | tuple[pd.DataFrame, np.ndarray]:
         """
         Perform semantic clustering on the DataFrame.
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index 42a6918..ed9619e 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -20,7 +20,6 @@ def sem_extract(
     postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
     safe_mode: bool = False,
     progress_bar_desc: str = "Extracting",
-    use_operator_cache: bool = False,
 ) -> SemanticExtractOutput:
     """
     Extracts attributes and values from a list of documents using a model.
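[The net effect of this patch on the decorator can be read in one piece. The condensed sketch below is not verbatim library code: make_operator_cache_key stands in for the inline hashing shown earlier, but the settings-gated flow matches the hunks in this patch.]

    from functools import wraps

    import lotus


    def operator_cache_sketch(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            model = lotus.settings.lm
            # The gate is now global settings, not a per-call kwarg.
            if lotus.settings.enable_operator_cache and model.cache:
                cache_key = make_operator_cache_key(args, kwargs)  # helper from the earlier sketch
                cached_result = model.cache.get(cache_key)
                if cached_result is not None:
                    return cached_result
                result = func(self, *args, **kwargs)
                model.cache.insert(cache_key, result)
                return result
            return func(self, *args, **kwargs)

        return wrapper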
@@ -83,7 +82,6 @@ def __call__( return_raw_outputs: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Extracting", - use_operator_cache: bool = False, ) -> pd.DataFrame: """ Extracts the attributes and values of a dataframe. @@ -118,7 +116,6 @@ def __call__( postprocessor=postprocessor, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, - use_operator_cache=use_operator_cache, ) new_df = self._obj.copy() diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index b6b7b19..d6253b8 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -27,7 +27,6 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", - use_operator_cache: bool = False, ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model. @@ -105,7 +104,6 @@ def learn_filter_cascade_thresholds( strategy=strategy, safe_mode=False, progress_bar_desc="Running oracle for threshold learning", - use_operator_cache=False, ).outputs best_combination, _ = learn_cascade_thresholds( @@ -152,7 +150,6 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", - use_operator_cache: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Applies semantic filter over a dataframe. @@ -250,7 +247,6 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc="Running helper LM", - use_operator_cache=use_operator_cache, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs assert helper_logprobs is not None @@ -331,7 +327,6 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", - use_operator_cache=use_operator_cache, ) for idx, large_idx in enumerate(low_conf_idxs): @@ -355,7 +350,6 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, - use_operator_cache=use_operator_cache, ) outputs = output.outputs raw_outputs = output.raw_outputs diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index 82f47f7..0050f49 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -30,7 +30,6 @@ def sem_join( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Join comparisons", - use_operator_cache: bool = False, ) -> SemanticJoinOutput: """ Joins two series using a model. @@ -92,7 +91,6 @@ def sem_join( default=default, strategy=strategy, show_progress_bar=False, - use_operator_cache=use_operator_cache, ) outputs = output.outputs raw_outputs = output.raw_outputs @@ -142,7 +140,6 @@ def sem_join_cascade( default: bool = True, strategy: str | None = None, safe_mode: bool = False, - use_operator_cache: bool = False, ) -> SemanticJoinOutput: """ Joins two series using a cascade helper model and a oracle model. 
@@ -238,8 +235,6 @@ def sem_join_cascade( cot_reasoning=cot_reasoning, default=default, strategy=strategy, - show_progress_bar=False, - use_operator_cache=use_operator_cache, ) pbar.update(num_large) pbar.close() @@ -518,7 +513,6 @@ def learn_join_cascade_threshold( cot_reasoning=cot_reasoning, strategy=strategy, progress_bar_desc="Running oracle for threshold learning", - use_operator_cache=False, ) (pos_threshold, neg_threshold), _ = learn_cascade_thresholds( @@ -566,7 +560,6 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Join comparisons", - use_operator_cache: bool = False, ) -> pd.DataFrame: """ Applies semantic join over a dataframe. @@ -680,7 +673,6 @@ def __call__( default=default, strategy=strategy, safe_mode=safe_mode, - use_operator_cache=use_operator_cache, ) else: output = sem_join( @@ -699,7 +691,6 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, - use_operator_cache=use_operator_cache, ) join_results = output.join_results all_raw_outputs = output.all_raw_outputs diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 46dc84a..9708bb1 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -22,7 +22,6 @@ def sem_map( strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", - use_operator_cache: bool = False, ) -> SemanticMapOutput: """ Maps a list of documents to a list of outputs using a model. @@ -94,7 +93,6 @@ def __call__( strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", - use_operator_cache: bool = False, ) -> pd.DataFrame: """ Applies semantic map over a dataframe. @@ -149,7 +147,6 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, - use_operator_cache=use_operator_cache, ) new_df = self._obj.copy() diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index fd20de9..de2df35 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -29,7 +29,6 @@ def __call__( n_rerank: int | None = None, return_scores: bool = False, suffix: str = "_sim_score", - use_operator_cache: bool = False, ) -> pd.DataFrame: """ Perform semantic search on the DataFrame. diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index 7d30d41..47d3cbe 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -32,7 +32,6 @@ def __call__( rsuffix: str = "", score_suffix: str = "", keep_index: bool = False, - use_operator_cache: bool = False, ) -> pd.DataFrame: """ Perform semantic similarity join on the DataFrame. diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 18101c5..b5ecd5e 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -398,7 +398,6 @@ def __call__( cascade_threshold: float | None = None, return_stats: bool = False, safe_mode: bool = False, - use_operator_cache: bool = False, ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Sorts the DataFrame based on the user instruction and returns the top K rows. 
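[Before the settings diff itself, it is worth seeing how the new flag is flipped in practice; this pattern is taken from the tests added later in the series:]

    import lotus

    lotus.settings.configure(enable_operator_cache=True)   # turn operator-level caching on
    assert lotus.settings.enable_operator_cache is True
    lotus.settings.configure(enable_operator_cache=False)  # and off again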
diff --git a/lotus/settings.py b/lotus/settings.py index ce12363..ba0b05b 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -13,6 +13,7 @@ class Settings: # Cache settings enable_cache: bool = False + enable_operator_cache: bool = False # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT From 1542ed233463cf3ede5658ea79b89b6a10345319 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 17:20:49 -0800 Subject: [PATCH 09/21] updates + tests --- .github/tests/lm_tests.py | 47 +++++++++++++++++++++++++++++--- docs/configurations.rst | 4 +-- examples/model_examples/cache.py | 2 +- lotus/cache.py | 6 ++-- lotus/settings.py | 2 +- tests/test_settings.py | 6 ++-- 6 files changed, 53 insertions(+), 14 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 94fcd59..e504233 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -398,7 +398,7 @@ def test_custom_tokenizer(): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_cache(setup_models, model): lm = setup_models[model] - lotus.settings.configure(lm=lm, enable_cache=True) + lotus.settings.configure(lm=lm, enable_message_cache=True) # Check that "What is the capital of France?" becomes cached first_batch = [ @@ -427,7 +427,7 @@ def test_cache(setup_models, model): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_disable_cache(setup_models, model): lm = setup_models[model] - lotus.settings.configure(lm=lm, enable_cache=False) + lotus.settings.configure(lm=lm, enable_message_cache=False) batch = [ [{"role": "user", "content": "Hello, world!"}], @@ -439,7 +439,7 @@ def test_disable_cache(setup_models, model): assert lm.stats.total_usage.cache_hits == 0 # Now enable cache. Note that the first batch is not cached. 
- lotus.settings.configure(enable_cache=True) + lotus.settings.configure(enable_message_cache=True) first_responses = lm(batch).outputs assert lm.stats.total_usage.cache_hits == 0 second_responses = lm(batch).outputs @@ -450,7 +450,7 @@ def test_disable_cache(setup_models, model): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_reset_cache(setup_models, model): lm = setup_models[model] - lotus.settings.configure(lm=lm, enable_cache=True) + lotus.settings.configure(lm=lm, enable_message_cache=True) batch = [ [{"role": "user", "content": "Hello, world!"}], @@ -472,3 +472,42 @@ def test_reset_cache(setup_models, model): assert lm.stats.total_usage.cache_hits == 3 lm(batch) assert lm.stats.total_usage.cache_hits == 3 + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_operator_cache(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True) + + batch = [ + [{"role": "user", "content": "Hello, world!"}], + [{"role": "user", "content": "What is the capital of France?"}], + ] + first_responses = lm(batch).outputs + assert lm.stats.total_usage.cache_hits == 0 + second_responses = lm(batch).outputs + assert lm.stats.total_usage.cache_hits == 2 + assert first_responses == second_responses + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_disable_operator_cache(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False) + + batch = [ + [{"role": "user", "content": "Hello, world!"}], + [{"role": "user", "content": "What is the capital of France?"}], + ] + lm(batch) + assert lm.stats.total_usage.cache_hits == 0 + lm(batch) + assert lm.stats.total_usage.cache_hits == 0 + + # Now enable operator cache. Note that the first batch is not cached. + lotus.settings.configure(enable_operator_cache=True) + first_responses = lm(batch).outputs + assert lm.stats.total_usage.cache_hits == 0 + second_responses = lm(batch).outputs + assert lm.stats.total_usage.cache_hits == 2 + assert first_responses == second_responses diff --git a/docs/configurations.rst b/docs/configurations.rst index 79c5a46..0ee0172 100644 --- a/docs/configurations.rst +++ b/docs/configurations.rst @@ -21,12 +21,12 @@ Using the Settings module Configurable Parameters -------------------------- -1. enable_cache: +1. enable_message_cache: * Description: Enables or Disables cahcing mechanisms * Default: False .. code-block:: python - lotus.settings.configure(enable_cache=True) + lotus.settings.configure(enable_message_cache=True) 2. 
setting RM: * Description: Configures the retrieval model diff --git a/examples/model_examples/cache.py b/examples/model_examples/cache.py index 5314ca5..95bc282 100644 --- a/examples/model_examples/cache.py +++ b/examples/model_examples/cache.py @@ -11,7 +11,7 @@ lm = LM(model="gpt-4o-mini", cache=cache) -lotus.settings.configure(lm=lm, enable_cache=True) # default caching is False +lotus.settings.configure(lm=lm, enable_message_cache=True) # default caching is False data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/cache.py b/lotus/cache.py index a5c90fb..a732a8f 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -20,7 +20,7 @@ def require_cache_enabled(func: Callable) -> Callable: @wraps(func) def wrapper(self, *args, **kwargs): - if not lotus.settings.enable_cache: + if not lotus.settings.enable_message_cache: return None return func(self, *args, **kwargs) @@ -52,9 +52,9 @@ def serialize(value): cached_result = model.cache.get(cache_key) if cached_result is not None: - print(f"Cache hit for {cache_key}") + lotus.logger.debug(f"Cache hit for {cache_key}") return cached_result - print(f"Cache miss for {cache_key}") + lotus.logger.debug(f"Cache miss for {cache_key}") result = func(self, *args, **kwargs) model.cache.insert(cache_key, result) diff --git a/lotus/settings.py b/lotus/settings.py index ba0b05b..99e5944 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -12,7 +12,7 @@ class Settings: reranker: lotus.models.Reranker | None = None # Cache settings - enable_cache: bool = False + enable_message_cache: bool = False enable_operator_cache: bool = False # Serialization setting diff --git a/tests/test_settings.py b/tests/test_settings.py index dc6f871..4f251bb 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -13,12 +13,12 @@ def test_initial_values(self, settings): assert settings.rm is None assert settings.helper_lm is None assert settings.reranker is None - assert settings.enable_cache is False + assert settings.enable_message_cache is False assert settings.serialization_format == SerializationFormat.DEFAULT def test_configure_method(self, settings): - settings.configure(enable_cache=True) - assert settings.enable_cache is True + settings.configure(enable_message_cache=True) + assert settings.enable_message_cache is True def test_invalid_setting(self, settings): with pytest.raises(ValueError, match="Invalid setting: invalid_setting"): From a5761e3f02c9e9ed134392316c0fbd2528829cce Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 17:22:32 -0800 Subject: [PATCH 10/21] remove file --- .../op_examples/~/.lotus/cache/lotus_cache.db | Bin 45056 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 examples/op_examples/~/.lotus/cache/lotus_cache.db diff --git a/examples/op_examples/~/.lotus/cache/lotus_cache.db b/examples/op_examples/~/.lotus/cache/lotus_cache.db deleted file mode 100644 index cc70c38696d3bd753def7749af8dfe98717624a6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 45056 zcmeHQNo?fCdTy;Vqmf1)Z{xB3jLyW19J_nCFEfzsmRf6T-$oK(@D<5olM=hx;v%W_ z;>45pyumc=SNXab9r&b@^RL1+zRruZ03VE{zpd#wY<(rW zurxJWSlY;z#x`C(Hs&s}1DqY1nIAd+x{W+vL7Ij=59`^fxs|brv8DS`dimwf-fP3x zIx<#O$NS!n4a{Caf$t>0m87T&oqTiwmKS3@d*J`V;IHGamk0l6@IMFteemA~|7Gx> z(x5n)90UP^06~BtKoB4Z5CjMU1Ob8oL4Y7Y5FiMgcLe$_KnXAP!s%i{_2UKn-kAOO zbi%2-qc6TP@I_|eiz}ZG-W>e%E5EsN?aG^j0R)qSAV3fx2oMAa0t5kq06~BtKoB4Z z5CqOM0%P5G2Ray0M;wPJL*rRiWjKU5nx$x3rf8j0DNfgD9y6TE@iND<21JlGLl6uO 
[... remainder of the base85-encoded binary payload omitted: reverse delta for the deleted SQLite cache database ...]

From: StanChan03
Date: Thu, 26 Dec 2024 17:38:16 -0800
Subject: [PATCH 11/21] increment cache hits
---
 lotus/cache.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lotus/cache.py b/lotus/cache.py
index a732a8f..5e9d3af 100644
--- a/lotus/cache.py
+++ b/lotus/cache.py
@@ -53,6 +53,7 @@ def serialize(value):
         cached_result = model.cache.get(cache_key)
         if cached_result is not None:
             lotus.logger.debug(f"Cache hit for {cache_key}")
+            model.stats.total_usage.cache_hits += 1
             return cached_result
         lotus.logger.debug(f"Cache miss for {cache_key}")

From 3955735190c1f12acb2f886c05c69eb5a0b8c427 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 17:59:37 -0800
Subject: [PATCH 12/21] operator cache hit type
---
 .github/tests/lm_tests.py | 12 ++++++------
 lotus/cache.py            |  2 +-
 lotus/types.py            |  1 +
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index e504233..c529aaf 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -484,9 +484,9 @@ def test_operator_cache(setup_models, model):
         [{"role": "user", "content": "What is the capital of France?"}],
     ]
     first_responses = lm(batch).outputs
-    assert lm.stats.total_usage.cache_hits == 0
+    assert lm.stats.total_usage.operator_cache_hits == 0
     second_responses = lm(batch).outputs
-    assert lm.stats.total_usage.cache_hits == 2
+    assert lm.stats.total_usage.operator_cache_hits == 2
     assert first_responses == second_responses
@@ -500,14 +500,14 @@ def test_disable_operator_cache(setup_models, model):
         [{"role": "user", "content": "What is the capital of France?"}],
     ]
     lm(batch)
-    assert lm.stats.total_usage.cache_hits == 0
+    assert lm.stats.total_usage.operator_cache_hits == 0
     lm(batch)
-    assert lm.stats.total_usage.cache_hits == 0
+    assert lm.stats.total_usage.operator_cache_hits == 0

     # Now enable operator cache. Note that the first batch is not cached.
lotus.settings.configure(enable_operator_cache=True) first_responses = lm(batch).outputs - assert lm.stats.total_usage.cache_hits == 0 + assert lm.stats.total_usage.operator_cache_hits == 0 second_responses = lm(batch).outputs - assert lm.stats.total_usage.cache_hits == 2 + assert lm.stats.total_usage.operator_cache_hits == 2 assert first_responses == second_responses diff --git a/lotus/cache.py b/lotus/cache.py index 5e9d3af..82c1c4c 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -53,7 +53,7 @@ def serialize(value): cached_result = model.cache.get(cache_key) if cached_result is not None: lotus.logger.debug(f"Cache hit for {cache_key}") - model.stats.total_usage.cache_hits += 1 + model.stats.total_usage.operator_cache_hits += 1 return cached_result lotus.logger.debug(f"Cache miss for {cache_key}") diff --git a/lotus/types.py b/lotus/types.py index 96b9079..c4cbb6d 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -32,6 +32,7 @@ class TotalUsage(BaseModel): total_tokens: int = 0 total_cost: float = 0.0 cache_hits: int = 0 + operator_cache_hits: int = 0 total_usage: TotalUsage = TotalUsage() From 4386b7c9c72a0584b764ca376a1cac3024fb4bae Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 18:20:37 -0800 Subject: [PATCH 13/21] operator_cache_hit testing --- .github/tests/lm_tests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index c529aaf..d888e7b 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -486,6 +486,7 @@ def test_operator_cache(setup_models, model): first_responses = lm(batch).outputs assert lm.stats.total_usage.operator_cache_hits == 0 second_responses = lm(batch).outputs + lotus.logger.debug(f"Operator Cache hits: {lm.stats.total_usage.operator_cache_hits}") assert lm.stats.total_usage.operator_cache_hits == 2 assert first_responses == second_responses @@ -509,5 +510,6 @@ def test_disable_operator_cache(setup_models, model): first_responses = lm(batch).outputs assert lm.stats.total_usage.operator_cache_hits == 0 second_responses = lm(batch).outputs + lotus.logger.debug(f"Operator Cache hits: {lm.stats.total_usage.operator_cache_hits}") assert lm.stats.total_usage.operator_cache_hits == 2 assert first_responses == second_responses From c966d6571933f0f7aea7c2bf77d93e4ea70e7db5 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 19:10:31 -0800 Subject: [PATCH 14/21] initialize op cache for tests --- .github/tests/lm_tests.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index d888e7b..91fc43c 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -5,6 +5,7 @@ from tokenizers import Tokenizer import lotus +from lotus.cache import CacheConfig, CacheFactory, CacheType from lotus.models import LM, SentenceTransformersRM from lotus.types import CascadeArgs @@ -476,7 +477,10 @@ def test_reset_cache(setup_models, model): @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_operator_cache(setup_models, model): - lm = setup_models[model] + cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000) + cache = CacheFactory.create_cache(cache_config) + + lm = LM(model="gpt-4o-mini", cache=cache) lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True) batch = [ @@ -486,14 +490,16 @@ def test_operator_cache(setup_models, model): first_responses = lm(batch).outputs assert lm.stats.total_usage.operator_cache_hits == 0 
second_responses = lm(batch).outputs - lotus.logger.debug(f"Operator Cache hits: {lm.stats.total_usage.operator_cache_hits}") assert lm.stats.total_usage.operator_cache_hits == 2 assert first_responses == second_responses @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_disable_operator_cache(setup_models, model): - lm = setup_models[model] + cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000) + cache = CacheFactory.create_cache(cache_config) + + lm = LM(model="gpt-4o-mini", cache=cache) lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False) batch = [ @@ -510,6 +516,5 @@ def test_disable_operator_cache(setup_models, model): first_responses = lm(batch).outputs assert lm.stats.total_usage.operator_cache_hits == 0 second_responses = lm(batch).outputs - lotus.logger.debug(f"Operator Cache hits: {lm.stats.total_usage.operator_cache_hits}") assert lm.stats.total_usage.operator_cache_hits == 2 assert first_responses == second_responses From 20d7deaccacd34988a4c5fd9e52d962e6d86792e Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 19:51:34 -0800 Subject: [PATCH 15/21] fix test --- .github/tests/lm_tests.py | 55 +++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 91fc43c..ce15fdd 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -483,15 +483,25 @@ def test_operator_cache(setup_models, model): lm = LM(model="gpt-4o-mini", cache=cache) lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True) - batch = [ - [{"role": "user", "content": "Hello, world!"}], - [{"role": "user", "content": "What is the capital of France?"}], - ] - first_responses = lm(batch).outputs + data = { + "Course Name": [ + "Dynamics and Control of Chemical Processes", + "Optimization Methods in Engineering", + "Chemical Kinetics and Catalysis", + "Transport Phenomena and Separations", + ] + } + + df = pd.DataFrame(data) + user_instruction = "What is a similar course to {Course Name}. Be concise?" + + first_response = df.sem_map(df, user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 - second_responses = lm(batch).outputs - assert lm.stats.total_usage.operator_cache_hits == 2 - assert first_responses == second_responses + + second_response = df.sem_map(df, user_instruction) + assert lm.stats.total_usage.operator_cache_hits == 1 + + assert first_response == second_response @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) @@ -502,19 +512,30 @@ def test_disable_operator_cache(setup_models, model): lm = LM(model="gpt-4o-mini", cache=cache) lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False) - batch = [ - [{"role": "user", "content": "Hello, world!"}], - [{"role": "user", "content": "What is the capital of France?"}], - ] - lm(batch) + data = { + "Course Name": [ + "Dynamics and Control of Chemical Processes", + "Optimization Methods in Engineering", + "Chemical Kinetics and Catalysis", + "Transport Phenomena and Separations", + ] + } + + df = pd.DataFrame(data) + user_instruction = "What is a similar course to {Course Name}. Be concise?" 
+ + first_response = df.sem_map(df, user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 - lm(batch) + + second_response = df.sem_map(df, user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 + assert first_response == second_response + # Now enable operator cache. Note that the first batch is not cached. lotus.settings.configure(enable_operator_cache=True) - first_responses = lm(batch).outputs + first_responses = df.sem_map(df, user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 - second_responses = lm(batch).outputs - assert lm.stats.total_usage.operator_cache_hits == 2 + second_responses = df.sem_map(df, user_instruction) + assert lm.stats.total_usage.operator_cache_hits == 1 assert first_responses == second_responses From 2ad4ce2c45c0ecbb9e70fad8f6f4417c4bdebfa8 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 26 Dec 2024 19:56:28 -0800 Subject: [PATCH 16/21] test typo --- .github/tests/lm_tests.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index ce15fdd..7bdd59e 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -495,10 +495,10 @@ def test_operator_cache(setup_models, model): df = pd.DataFrame(data) user_instruction = "What is a similar course to {Course Name}. Be concise?" - first_response = df.sem_map(df, user_instruction) + first_response = df.sem_map(user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 - second_response = df.sem_map(df, user_instruction) + second_response = df.sem_map(user_instruction) assert lm.stats.total_usage.operator_cache_hits == 1 assert first_response == second_response @@ -524,18 +524,18 @@ def test_disable_operator_cache(setup_models, model): df = pd.DataFrame(data) user_instruction = "What is a similar course to {Course Name}. Be concise?" - first_response = df.sem_map(df, user_instruction) + first_response = df.sem_map(user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 - second_response = df.sem_map(df, user_instruction) + second_response = df.sem_map(user_instruction) assert lm.stats.total_usage.operator_cache_hits == 0 assert first_response == second_response # Now enable operator cache. Note that the first batch is not cached. 
     lotus.settings.configure(enable_operator_cache=True)
-    first_responses = df.sem_map(df, user_instruction)
+    first_responses = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 0
-    second_responses = df.sem_map(df, user_instruction)
+    second_responses = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 1
     assert first_responses == second_responses

From e846172bbf0107c861a4876be0465d9d5e2b5a0f Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 20:09:41 -0800
Subject: [PATCH 17/21] include expected response

---
 .github/tests/lm_tests.py | 42 +++++++++++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index 7bdd59e..1917bd5 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -492,8 +492,25 @@ def test_operator_cache(setup_models, model):
         ]
     }
 
+    expected_response = pd.DataFrame(
+        {
+            "Course Name": [
+                "Dynamics and Control of Chemical Processes",
+                "Optimization Methods in Engineering",
+                "Chemical Kinetics and Catalysis",
+                "Transport Phenomena and Separations",
+            ],
+            "_map": [
+                "«Process Dynamics and Control»",
+                "«Advanced Optimization Techniques in Engineering»",
+                "«Reaction Kinetics and Mechanisms»",
+                "Fluid Mechanics and Mass Transfer",
+            ],
+        }
+    )
+
     df = pd.DataFrame(data)
-    user_instruction = "What is a similar course to {Course Name}. Be concise?"
+    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."
 
     first_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 0
@@ -501,7 +518,7 @@ def test_operator_cache(setup_models, model):
     second_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 1
 
-    assert first_response == second_response
+    assert first_response == second_response == expected_response
@@ -521,8 +538,25 @@ def test_disable_operator_cache(setup_models, model):
         ]
     }
 
+    expected_response = pd.DataFrame(
+        {
+            "Course Name": [
+                "Dynamics and Control of Chemical Processes",
+                "Optimization Methods in Engineering",
+                "Chemical Kinetics and Catalysis",
+                "Transport Phenomena and Separations",
+            ],
+            "_map": [
+                "«Process Dynamics and Control»",
+                "«Advanced Optimization Techniques in Engineering»",
+                "«Reaction Kinetics and Mechanisms»",
+                "Fluid Mechanics and Mass Transfer",
+            ],
+        }
+    )
+
     df = pd.DataFrame(data)
-    user_instruction = "What is a similar course to {Course Name}. Be concise?"
+    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."
     first_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 0
@@ -538,4 +572,4 @@ def test_disable_operator_cache(setup_models, model):
     assert lm.stats.total_usage.operator_cache_hits == 0
     second_responses = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 1
-    assert first_responses == second_responses
+    assert first_responses == second_responses == expected_response

From 0de3d6e5f82547179440b6a9ae7062d2b89 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 20:51:46 -0800
Subject: [PATCH 18/21] fix: normalize map outputs and compare with
 pandas.testing

---
 .github/tests/lm_tests.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index 1917bd5..dd26a06 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -501,9 +501,9 @@ def test_operator_cache(setup_models, model):
                 "Transport Phenomena and Separations",
             ],
             "_map": [
-                "«Process Dynamics and Control»",
-                "«Advanced Optimization Techniques in Engineering»",
-                "«Reaction Kinetics and Mechanisms»",
+                "Process Dynamics and Control",
+                "Advanced Optimization Techniques in Engineering",
+                "Reaction Kinetics and Mechanisms",
                 "Fluid Mechanics and Mass Transfer",
             ],
         }
@@ -518,7 +518,12 @@ def test_operator_cache(setup_models, model):
     second_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 1
 
-    assert first_response == second_response == expected_response
+    result_1 = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
+    result_2 = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
+
+    pd.testing.assert_frame_equal(result_1, result_2)
+    pd.testing.assert_frame_equal(result_1, expected_response)
+    pd.testing.assert_frame_equal(result_2, expected_response)
@@ -552,9 +557,9 @@ def test_disable_operator_cache(setup_models, model):
                 "Transport Phenomena and Separations",
             ],
             "_map": [
-                "«Process Dynamics and Control»",
-                "«Advanced Optimization Techniques in Engineering»",
-                "«Reaction Kinetics and Mechanisms»",
+                "Process Dynamics and Control",
+                "Advanced Optimization Techniques in Engineering",
+                "Reaction Kinetics and Mechanisms",
                 "Fluid Mechanics and Mass Transfer",
             ],
         }
@@ -569,12 +574,15 @@ def test_disable_operator_cache(setup_models, model):
     second_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 0
 
-    assert first_response == second_response
+    pd.testing.assert_frame_equal(first_response, second_response)
 
-    # Now enable operator cache. Note that the first batch is not cached.
+    # Now enable operator cache.
     lotus.settings.configure(enable_operator_cache=True)
-    first_responses = df.sem_map(user_instruction)
+    first_responses = df.sem_map(user_instruction)["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
     assert lm.stats.total_usage.operator_cache_hits == 0
-    second_responses = df.sem_map(user_instruction)
+    second_responses = df.sem_map(user_instruction)["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
     assert lm.stats.total_usage.operator_cache_hits == 1
-    assert first_responses == second_responses == expected_response
+
+    pd.testing.assert_frame_equal(first_responses, second_responses)
+    pd.testing.assert_frame_equal(first_responses, expected_response)
+    pd.testing.assert_frame_equal(second_responses, expected_response)

From e4039211d31a9d5b77ecbf9aa343bb0a75ed8740 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 20:58:47 -0800
Subject: [PATCH 19/21] update: reset cache state and rework output filtering

---
 .github/tests/lm_tests.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index dd26a06..da3dfab 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -525,6 +525,10 @@ def test_operator_cache(setup_models, model):
     pd.testing.assert_frame_equal(result_1, expected_response)
     pd.testing.assert_frame_equal(result_2, expected_response)
 
+    lm.reset_cache()
+    lm.reset_stats()
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
 
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_disable_operator_cache(setup_models, model):
@@ -574,9 +578,11 @@ def test_disable_operator_cache(setup_models, model):
     # Now enable operator cache.
     lotus.settings.configure(enable_operator_cache=True)
-    first_responses = df.sem_map(user_instruction)["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
+    first_responses = df.sem_map(user_instruction)
+    first_responses = first_responses[first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
     assert lm.stats.total_usage.operator_cache_hits == 0
-    second_responses = df.sem_map(user_instruction)["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
+    second_responses = df.sem_map(user_instruction)
+    second_responses = second_responses[second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
     assert lm.stats.total_usage.operator_cache_hits == 1

From 53ec0401be3335745fe6e66706069b17475052be Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 21:00:21 -0800
Subject: [PATCH 20/21] update: align output filtering in test_operator_cache

---
 .github/tests/lm_tests.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index da3dfab..eb1fc22 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -518,8 +518,8 @@ def test_operator_cache(setup_models, model):
     second_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 1
 
-    result_1 = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
-    result_2 = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)
+    result_1 = first_response[first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
+    result_2 = second_response[second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
 
     pd.testing.assert_frame_equal(result_1, result_2)

From 8a5dbf8d6558c7e2da2a75db127d44e4e92c9de6 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Thu, 26 Dec 2024 21:09:34 -0800
Subject: [PATCH 21/21] update: normalize _map columns before comparing

---
 .github/tests/lm_tests.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index eb1fc22..3a18bf8 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -518,12 +518,13 @@ def test_operator_cache(setup_models, model):
     second_response = df.sem_map(user_instruction)
     assert lm.stats.total_usage.operator_cache_hits == 1
 
-    result_1 = first_response[first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
-    result_2 = second_response[second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
+    first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
 
-    pd.testing.assert_frame_equal(result_1, result_2)
-    pd.testing.assert_frame_equal(result_1, expected_response)
-    pd.testing.assert_frame_equal(result_2, expected_response)
+    pd.testing.assert_frame_equal(first_response, second_response)
+    pd.testing.assert_frame_equal(first_response, expected_response)
+    pd.testing.assert_frame_equal(second_response, expected_response)
 
     lm.reset_cache()
     lm.reset_stats()
@@ -578,12 +579,14 @@ def test_disable_operator_cache(setup_models, model):
     # Now enable operator cache.
     lotus.settings.configure(enable_operator_cache=True)
     first_responses = df.sem_map(user_instruction)
-    first_responses = first_responses[first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
+    first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
    assert lm.stats.total_usage.operator_cache_hits == 0
     second_responses = df.sem_map(user_instruction)
-    second_responses = second_responses[second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True)]
+    second_responses["_map"] = second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
     assert lm.stats.total_usage.operator_cache_hits == 1
 
+    expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+
     pd.testing.assert_frame_equal(first_responses, second_responses)
     pd.testing.assert_frame_equal(first_responses, expected_response)
     pd.testing.assert_frame_equal(second_responses, expected_response)
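
Net effect of the series, distilled into a standalone usage sketch. This is
illustrative only, not part of any patch: the import path lotus.cache for
CacheConfig, CacheFactory, and CacheType is an assumption based on where the
operator_cache decorator lives, since the test file's import block never
appears in these diffs. Everything else (LM, settings.configure flags,
sem_map, the stats counters, reset_cache/reset_stats) is taken directly from
the code shown above.

    import pandas as pd

    import lotus
    # Assumed import path; CacheType is defined in lotus/cache.py above.
    from lotus.cache import CacheConfig, CacheFactory, CacheType
    from lotus.models import LM

    # Back the LM with a SQLite cache and enable both cache layers,
    # exactly as the tests do.
    cache = CacheFactory.create_cache(CacheConfig(cache_type=CacheType.SQLITE, max_size=1000))
    lm = LM(model="gpt-4o-mini", cache=cache)
    lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)

    df = pd.DataFrame({"Course Name": ["Chemical Kinetics and Catalysis"]})
    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."

    # First call: operator cache miss, so the model is invoked.
    first = df.sem_map(user_instruction)
    assert lm.stats.total_usage.operator_cache_hits == 0

    # Identical call: the whole operator result is served from the cache.
    second = df.sem_map(user_instruction)
    assert lm.stats.total_usage.operator_cache_hits == 1

    # Clear cached entries and usage counters between runs.
    lm.reset_cache()
    lm.reset_stats()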