Skip to content

Commit

Permalink
fix to filter and join to be consistent api
Browse files Browse the repository at this point in the history
  • Loading branch information
liana313 committed Dec 10, 2024
1 parent 0f886eb commit 833511f
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 32 deletions.
14 changes: 6 additions & 8 deletions examples/op_examples/filter_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import lotus
from lotus.models import LM
from lotus.types import CascadeArgs


gpt_35_turbo = LM("gpt-3.5-turbo")
gpt_4o = LM("gpt-4o")
Expand Down Expand Up @@ -116,13 +118,9 @@
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df, stats = df.sem_filter(
user_instruction=user_instruction,
learn_cascade_threshold_sample_percentage=0.5,
recall_target=0.9,
precision_target=0.9,
failure_probability=0.2,
return_stats=True,
)

cascade_args = CascadeArgs(recall_target=0.9, precision_target=0.9, sampling_percentage=0.5, failure_probability=0.2)

df, stats = df.sem_filter(user_instruction=user_instruction, cascade_args=cascade_args, return_stats=True)
print(df)
print(stats)
4 changes: 2 additions & 2 deletions examples/op_examples/join_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import SemJoinCascadeArgs
from lotus.types import CascadeArgs

lm = LM(model="gpt-4o-mini")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
Expand Down Expand Up @@ -124,7 +124,7 @@
df2 = pd.DataFrame(data2)
join_instruction = "By taking {Course Name:left} I will learn {Skill:right}"

cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7)
cascade_args = CascadeArgs(recall_target=0.7, precision_target=0.7)
res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True)


Expand Down
33 changes: 15 additions & 18 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import lotus
from lotus.templates import task_instructions
from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput, CascadeArgs
from lotus.utils import show_safe_mode

from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
Expand Down Expand Up @@ -148,11 +148,7 @@ def __call__(
examples: pd.DataFrame | None = None,
helper_examples: pd.DataFrame | None = None,
strategy: str | None = None,
helper_strategy: str | None = None,
learn_cascade_threshold_sample_percentage: int | None = None,
recall_target: float | None = None,
precision_target: float | None = None,
failure_probability: float | None = None,
cascade_args: CascadeArgs | None = None,
return_stats: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Filtering",
Expand All @@ -168,11 +164,11 @@ def __call__(
examples (pd.DataFrame | None): The examples dataframe. Defaults to None.
helper_examples (pd.DataFrame | None): The helper examples dataframe. Defaults to None.
strategy (str | None): The reasoning strategy. Defaults to None.
helper_strategy (str | None): The reasoning strategy for helper. Defaults to None.
learn_cascade_threshold_sample_size (Optional[int]): The percentage of samples from which to learn thresholds when cascading.
recall_target (float | None): The specified recall target when cascading.
precision_target (float | None): The specified precision target when cascading.
failure_probability (float | None): The specified failure probability for precision/recall targets when cascading.
cascade_args (CascadeArgs | None): The arguments for filter cascade. Defaults to None.
recall_target (float | None): The target recall. Defaults to None.
precision_target (float | None): The target precision when cascading. Defaults to None.
sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1.
failure_probability (float): The failure probability when cascading. Defaults to 0.2.
return_stats (bool): Whether to return statistics. Defaults to False.
Returns:
Expand All @@ -182,6 +178,7 @@ def __call__(
lotus.logger.debug(user_instruction)
col_li = lotus.nl_expression.parse_cols(user_instruction)
lotus.logger.debug(col_li)
helper_strategy = strategy

# check that column exists
for column in col_li:
Expand All @@ -205,7 +202,7 @@ def __call__(
cot_reasoning = examples["Reasoning"].tolist()

pos_cascade_threshold, neg_cascade_threshold = None, None
if learn_cascade_threshold_sample_percentage is not None:
if cascade_args is not None:
# Get few-shot examples for small LM
helper_examples_multimodal_data = None
helper_examples_answers = None
Expand All @@ -218,12 +215,12 @@ def __call__(
if helper_strategy == "cot":
helper_cot_reasoning = examples["Reasoning"].tolist()

if learn_cascade_threshold_sample_percentage and lotus.settings.helper_lm:
if cascade_args and lotus.settings.helper_lm:
if helper_strategy == "cot":
lotus.logger.error("CoT not supported for helper models in cascades.")
raise Exception

if recall_target is None or precision_target is None or failure_probability is None:
if cascade_args.recall_target is None or cascade_args.precision_target is None or cascade_args.failure_probability is None:
lotus.logger.error(
"Recall target, precision target, and confidence need to be specified for learned thresholds."
)
Expand Down Expand Up @@ -251,7 +248,7 @@ def __call__(
helper_true_probs = calibrate_llm_logprobs(formatted_helper_logprobs.true_probs)

sample_indices, correction_factors = importance_sampling(
helper_true_probs, learn_cascade_threshold_sample_percentage
helper_true_probs, cascade_args.sampling_percentage
)
sample_df = self._obj.loc[sample_indices]
sample_multimodal_data = task_instructions.df2multimodal_info(sample_df, col_li)
Expand All @@ -263,9 +260,9 @@ def __call__(
lm=lotus.settings.lm,
formatted_usr_instr=formatted_usr_instr,
default=default,
recall_target=recall_target,
precision_target=precision_target,
delta=failure_probability / 2,
recall_target=cascade_args.recall_target,
precision_target=cascade_args.precision_target,
delta=cascade_args.failure_probability / 2,
helper_true_probs=sample_helper_true_probs,
sample_correction_factors=sample_correction_factors,
examples_multimodal_data=examples_multimodal_data,
Expand Down
6 changes: 3 additions & 3 deletions lotus/sem_ops/sem_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import lotus
from lotus.templates import task_instructions
from lotus.types import SemanticJoinOutput, SemJoinCascadeArgs
from lotus.types import SemanticJoinOutput, CascadeArgs
from lotus.utils import show_safe_mode

from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds
Expand Down Expand Up @@ -576,7 +576,7 @@ def __call__(
examples: pd.DataFrame | None = None,
strategy: str | None = None,
default: bool = True,
cascade_args: SemJoinCascadeArgs | None = None,
cascade_args: CascadeArgs | None = None,
return_stats: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Join comparisons",
Expand All @@ -593,7 +593,7 @@ def __call__(
examples (pd.DataFrame | None): The examples dataframe. Defaults to None.
strategy (str | None): The reasoning strategy. Defaults to None.
default (bool): The default value for the join in case of parsing errors. Defaults to True.
cascade_args (SemJoinCascadeArgs | None): The arguments for join cascade. Defaults to None.
cascade_args (CascadeArgs | None): The arguments for join cascade. Defaults to None.
recall_target (float | None): The target recall. Defaults to None.
precision_target (float | None): The target precision when cascading. Defaults to None.
sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1.
Expand Down
2 changes: 1 addition & 1 deletion lotus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class SemanticJoinOutput(StatsMixin):
all_explanations: list[str | None]


class SemJoinCascadeArgs(BaseModel):
class CascadeArgs(BaseModel):
recall_target: float | None = None
precision_target: float | None = None
sampling_percentage: float = 0.1
Expand Down

0 comments on commit 833511f

Please sign in to comment.