-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
operator level cache #65
Changes from 8 commits
9bd4b09
2057704
03947bd
24f21b3
7ecd48d
813d83c
74cec20
335df10
ebced3c
1542ed2
a5761e3
dcfd00c
3955735
4386b7c
c966d65
20d7dea
2ad4ce2
e846172
0de3d6e
e403921
53ec040
8a5dbf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import pandas as pd | ||
|
||
import lotus.models | ||
from lotus.cache import operator_cache | ||
from lotus.templates import task_instructions | ||
from lotus.types import LMOutput, SemanticAggOutput | ||
|
||
|
@@ -14,6 +15,7 @@ def sem_agg( | |
partition_ids: list[int], | ||
safe_mode: bool = False, | ||
progress_bar_desc: str = "Aggregating", | ||
use_operator_cache: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. High level question - should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. putting it in settings is better I'd say. I can put it in settings |
||
) -> SemanticAggOutput: | ||
""" | ||
Aggregates multiple documents into a single answer using a model. | ||
|
@@ -148,6 +150,7 @@ def process_group(args): | |
group, user_instruction, all_cols, suffix, progress_bar_desc = args | ||
return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) | ||
|
||
@operator_cache | ||
def __call__( | ||
self, | ||
user_instruction: str, | ||
|
@@ -156,6 +159,7 @@ def __call__( | |
group_by: list[str] | None = None, | ||
safe_mode: bool = False, | ||
progress_bar_desc: str = "Aggregating", | ||
use_operator_cache: bool = False, | ||
) -> pd.DataFrame: | ||
""" | ||
Applies semantic aggregation over a dataframe. | ||
|
@@ -213,6 +217,7 @@ def __call__( | |
partition_ids, | ||
safe_mode=safe_mode, | ||
progress_bar_desc=progress_bar_desc, | ||
use_operator_cache=use_operator_cache, | ||
) | ||
|
||
# package answer in a dataframe | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use
lotus.logger
rather than prints.