
sem_join + sem_filter pbars (#46)
Added a flag to sem_filter so that sem_join can create its own combined pbar

---------

Co-authored-by: Sid Jha <[email protected]>
Co-authored-by: liana313 <[email protected]>
3 people authored Dec 10, 2024
1 parent 99f96b2 commit 733f253
Showing 17 changed files with 183 additions and 91 deletions.
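
At the core of the change are two new keyword arguments on `LM.__call__` (see the `lotus/models/lm.py` diff below). A minimal sketch of how a caller might use them, assuming a configured model; the model name and message are illustrative:

```python
import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

messages = [[{"role": "user", "content": "What is a semantic join?"}]]

# Default: a tqdm bar is shown for uncached messages, with a custom label.
out = lm(messages, progress_bar_desc="Demo LM calls")

# Operators that draw their own combined bar (e.g. sem_join) pass show_progress_bar=False.
out = lm(messages, show_progress_bar=False)
print(out.outputs[0])
```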
48 changes: 28 additions & 20 deletions .github/tests/lm_tests.py
@@ -6,7 +6,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
-from lotus.types import SemJoinCascadeArgs
+from lotus.types import CascadeArgs

################################################################################
# Setup
@@ -270,10 +270,12 @@ def test_filter_cascade(setup_models):
# All filters resolved by the helper model
filtered_df, stats = df.sem_filter(
user_instruction=user_instruction,
-        learn_cascade_threshold_sample_percentage=0.5,
-        recall_target=0.9,
-        precision_target=0.9,
-        failure_probability=0.2,
+        cascade_args=CascadeArgs(
+            learn_cascade_threshold_sample_percentage=0.5,
+            recall_target=0.9,
+            precision_target=0.9,
+            failure_probability=0.2,
+        ),
return_stats=True,
)

@@ -286,10 +288,12 @@ def test_join_cascade(setup_models):
def test_join_cascade(setup_models):
models = setup_models
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
-    lotus.settings.configure(lm=models["gpt-4o-mini"],
-                             rm=rm,
-                             min_join_cascade_size=10,  # for smaller testing
-                             cascade_IS_random_seed=42)
+    lotus.settings.configure(
+        lm=models["gpt-4o-mini"],
+        rm=rm,
+        min_join_cascade_size=10,  # for smaller testing
+        cascade_IS_random_seed=42,
+    )

data1 = {
"School": [
@@ -308,37 +312,41 @@ def test_join_cascade(setup_models):
"Yale University",
"Cornell University",
"University of Pennsylvania",
-        ]}
+        ]
+    }
data2 = {"School Type": ["Public School", "Private School"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2)
join_instruction = "{School} is a {School Type}"
-    expected_pairs = [("University of California, Berkeley", "Public School"), ("Stanford University", "Private School")]
+    expected_pairs = [
+        ("University of California, Berkeley", "Public School"),
+        ("Stanford University", "Private School"),
+    ]

# Cascade join
joined_df, stats = df1.sem_join(
-        df2, join_instruction,
-        cascade_args=SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7),
-        return_stats=True)
+        df2, join_instruction, cascade_args=CascadeArgs(recall_target=0.7, precision_target=0.7), return_stats=True
+    )

for pair in expected_pairs:
school, school_type = pair
-        exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any()
+        exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
assert exists, f"Expected pair {pair} does not exist in the dataframe!"
assert stats["join_resolved_by_helper_model"] > 0, stats

# All joins resolved by the large model
joined_df, stats = df1.sem_join(
-        df2, join_instruction,
-        cascade_args=SemJoinCascadeArgs(recall_target=1.0, precision_target=1.0),
-        return_stats=True)
+        df2, join_instruction, cascade_args=CascadeArgs(recall_target=1.0, precision_target=1.0), return_stats=True
+    )

for pair in expected_pairs:
school, school_type = pair
-        exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any()
+        exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
assert exists, f"Expected pair {pair} does not exist in the dataframe!"
-    assert stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"], stats  # helper negative can still meet the precision target
+    assert (
+        stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"]
+    ), stats  # helper negative can still meet the precision target
assert stats["join_helper_positive"] == 0, stats


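The test changes above boil down to one pattern: a single `CascadeArgs` dataclass (replacing `SemJoinCascadeArgs`) now configures both `sem_filter` and `sem_join` cascades. A condensed sketch with illustrative data and thresholds; a real cascade run also assumes the helper/oracle models that the test fixture configures:

```python
import pandas as pd

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import CascadeArgs

lotus.settings.configure(
    lm=LM(model="gpt-4o-mini"),
    rm=SentenceTransformersRM(model="intfloat/e5-base-v2"),
)

df1 = pd.DataFrame({"School": ["Stanford University", "University of California, Berkeley"]})
df2 = pd.DataFrame({"School Type": ["Public School", "Private School"]})

# The same CascadeArgs type serves sem_filter and sem_join.
args = CascadeArgs(recall_target=0.7, precision_target=0.7)
joined, stats = df1.sem_join(df2, "{School} is a {School Type}", cascade_args=args, return_stats=True)
print(stats["join_resolved_by_helper_model"])
```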
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -29,7 +29,7 @@ jobs:
pip install ruff==0.7.2
- name: Run ruff
-        run: ruff check .
+        run: ruff check lotus/

mypy:
name: Type Check
5 changes: 2 additions & 3 deletions examples/op_examples/agg.py
@@ -1,12 +1,11 @@
import pandas as pd

import lotus
-from lotus.models import LM, SentenceTransformersRM
+from lotus.models import LM

lm = LM(model="gpt-4o-mini")
-rm = SentenceTransformersRM(model="intfloat/e5-base-v2")

-lotus.settings.configure(lm=lm, rm=rm)
+lotus.settings.configure(lm=lm)
data = {
"Course Name": [
"Probability and Random Processes",
14 changes: 6 additions & 8 deletions examples/op_examples/filter_cascade.py
@@ -2,6 +2,8 @@

import lotus
from lotus.models import LM
+from lotus.types import CascadeArgs
+

gpt_35_turbo = LM("gpt-3.5-turbo")
gpt_4o = LM("gpt-4o")
@@ -116,13 +118,9 @@
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
-df, stats = df.sem_filter(
-    user_instruction=user_instruction,
-    learn_cascade_threshold_sample_percentage=0.5,
-    recall_target=0.9,
-    precision_target=0.9,
-    failure_probability=0.2,
-    return_stats=True,
-)
+
+cascade_args = CascadeArgs(recall_target=0.9, precision_target=0.9, sampling_percentage=0.5, failure_probability=0.2)
+
+df, stats = df.sem_filter(user_instruction=user_instruction, cascade_args=cascade_args, return_stats=True)
print(df)
print(stats)
4 changes: 2 additions & 2 deletions examples/op_examples/join_cascade.py
@@ -2,7 +2,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
-from lotus.types import SemJoinCascadeArgs
+from lotus.types import CascadeArgs

lm = LM(model="gpt-4o-mini")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
@@ -124,7 +124,7 @@
df2 = pd.DataFrame(data2)
join_instruction = "By taking {Course Name:left} I will learn {Skill:right}"

-cascade_args = SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7)
+cascade_args = CascadeArgs(recall_target=0.7, precision_target=0.7)
res, stats = df1.sem_join(df2, join_instruction, cascade_args=cascade_args, return_stats=True)


2 changes: 1 addition & 1 deletion lotus/models/cross_encoder_reranker.py
@@ -23,6 +23,6 @@ def __init__(
self.model = CrossEncoder(model, device=device) # type: ignore # CrossEncoder has wrong type stubs

def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput:
-        results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size)
+        results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size, show_progress_bar=False)
indices = [int(result["corpus_id"]) for result in results]
return RerankerOutput(indices=indices)
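
Passing `show_progress_bar=False` keeps the reranker's internal bar from interleaving with the LM-level bar added in `lotus/models/lm.py` below. A standalone sketch of the same `rank()` call, assuming `sentence-transformers` is installed; the model name is illustrative:

```python
from sentence_transformers import CrossEncoder

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
docs = ["Paris is the capital of France.", "The Nile is a river."]

# rank() returns dicts with a "corpus_id" key, as used above.
results = ce.rank("capital of France", docs, top_k=2, show_progress_bar=False)
indices = [int(r["corpus_id"]) for r in results]
print(indices)
```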
26 changes: 22 additions & 4 deletions lotus/models/lm.py
@@ -42,7 +42,11 @@ def __init__(
self.cache = Cache(max_cache_size)

def __call__(
-        self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
+        self,
+        messages: list[list[dict[str, str]]],
+        show_progress_bar: bool = True,
+        progress_bar_desc: str = "Processing uncached messages",
+        **kwargs: dict[str, Any],
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

@@ -59,7 +63,9 @@ def __call__(
self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
-        uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs)
+        uncached_responses = self._process_uncached_messages(
+            uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
+        )

# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
@@ -74,12 +80,24 @@ def __call__(

return LMOutput(outputs=outputs, logprobs=logprobs)

-    def _process_uncached_messages(self, uncached_data, all_kwargs):
+    def _process_uncached_messages(self, uncached_data, all_kwargs, show_progress_bar, progress_bar_desc):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
-        for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
+        total_calls = len(uncached_data)
+
+        pbar = tqdm(
+            total=total_calls,
+            desc=progress_bar_desc,
+            disable=not show_progress_bar,
+            bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
+        )
+        for i in range(0, total_calls, self.max_batch_size):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
+
+            pbar.update(len(batch))
+        pbar.close()

return uncached_responses

def _cache_response(self, response, hash):
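The rewritten loop trades tqdm-as-iterator for a manually driven bar, so the caller controls the description, visibility, and units ("LM calls" rather than batches). A self-contained preview of the same pattern and `bar_format`; the sleep stands in for a batch of LM calls:

```python
import time

from tqdm import tqdm

total_calls, batch_size = 20, 4
pbar = tqdm(
    total=total_calls,
    desc="Processing uncached messages",
    disable=False,  # a caller like sem_join would pass disable=True
    bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
)
for i in range(0, total_calls, batch_size):
    time.sleep(0.1)  # stand-in for batch_completion(...)
    pbar.update(min(batch_size, total_calls - i))
pbar.close()
```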
2 changes: 1 addition & 1 deletion lotus/models/sentence_transformers_rm.py
@@ -31,7 +31,7 @@ def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
batch = docs[i : i + self.max_batch_size]
_batch = convert_to_base_data(batch)
torch_embeddings = self.transformer.encode(
-                _batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings
+                _batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings, show_progress_bar=False
)
assert isinstance(torch_embeddings, torch.Tensor)
cpu_embeddings = torch_embeddings.cpu().numpy()
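Same idea as the reranker change: the embedder's per-batch bar is suppressed because the caller already reports progress. A sketch of a caller-owned outer bar over batched encoding, assuming `sentence-transformers` is installed; the model and batch size are illustrative:

```python
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

model = SentenceTransformer("intfloat/e5-base-v2")
docs = [f"passage: document {i}" for i in range(256)]
batch_size = 64

embeddings = []
with tqdm(total=len(docs), desc="Embedding") as pbar:
    for i in range(0, len(docs), batch_size):
        batch = docs[i : i + batch_size]
        # Inner bar off, so only the outer bar writes to the terminal.
        embeddings.extend(model.encode(batch, show_progress_bar=False))
        pbar.update(len(batch))
```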
18 changes: 12 additions & 6 deletions lotus/sem_ops/cascade_utils.py
@@ -28,7 +28,7 @@ def importance_sampling(
sample_size = int(sample_percentage * len(proxy_scores))
sample_indices = np.random.choice(indices, sample_size, p=sample_w)

-    correction_factors = (1/len(proxy_scores)) / w
+    correction_factors = (1 / len(proxy_scores)) / w

return sample_indices, correction_factors

@@ -65,8 +65,14 @@ def recall(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float:
sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold]
total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs)
recall = (
-        sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)
-    ) / total_correct if total_correct > 0 else 0.0
+        (
+            sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1])
+            + sum(x[1] * x[2] for x in sent_to_oracle)
+        )
+        / total_correct
+        if total_correct > 0
+        else 0.0
+    )
return recall

def precision(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float:
@@ -80,8 +86,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float:

def calculate_tau_neg(sorted_pairs: list[tuple[float, bool, float]], tau_pos: float, recall_target: float) -> float:
return max(
-        (x[0] for x in sorted_pairs[::-1] if recall(tau_pos, x[0], sorted_pairs) >= recall_target),
-        default=0
+        (x[0] for x in sorted_pairs[::-1] if recall(tau_pos, x[0], sorted_pairs) >= recall_target), default=0
)

# Pair helper model probabilities with helper correctness and oracle answer
@@ -135,6 +140,7 @@ def calculate_tau_neg(sorted_pairs: list[tuple[float, bool, float]], tau_pos: float, recall_target: float) -> float:

return best_combination, oracle_calls

+
def calibrate_sem_sim_join(true_score: list[float]) -> list[float]:
true_score = list(np.clip(true_score, 0, 1))
-    return true_score
\ No newline at end of file
+    return true_score
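
The whitespace fix in `importance_sampling` touches the line that makes the sampler unbiased: items drawn with probability `w` are re-weighted by `(1 / N) / w`, so weighted averages over the biased sample still estimate uniform averages. A numeric sketch; the weighting scheme (oversampling scores near 0.5) is an assumption for illustration, not taken from the partially shown function:

```python
import numpy as np

rng = np.random.default_rng(42)
proxy_scores = rng.random(10_000)

# Assumed sampling weights: oversample uncertain items near 0.5.
raw = 1.0 - np.abs(proxy_scores - 0.5)
w = raw / raw.sum()  # normalized, so it can serve as choice() probabilities

sample_indices = rng.choice(len(proxy_scores), size=500, p=w)
correction_factors = (1 / len(proxy_scores)) / w  # as in the hunk above

# The corrected sample mean tracks the plain population mean.
est = np.mean(proxy_scores[sample_indices] * correction_factors[sample_indices])
print(f"estimate {est:.4f} vs true mean {proxy_scores.mean():.4f}")
```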
7 changes: 5 additions & 2 deletions lotus/sem_ops/sem_agg.py
@@ -13,6 +13,7 @@ def sem_agg(
user_instruction: str,
partition_ids: list[int],
safe_mode: bool = False,
+    progress_bar_desc: str = "Aggregating",
) -> SemanticAggOutput:
"""
Aggregates multiple documents into a single answer using a model.
@@ -115,7 +116,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
batch.append([{"role": "user", "content": prompt}])
new_partition_ids.append(cur_partition_id)

-    lm_output: LMOutput = model(batch)
+    lm_output: LMOutput = model(batch, progress_bar_desc=progress_bar_desc)

summaries = lm_output.outputs
partition_ids = new_partition_ids
@@ -149,6 +150,7 @@ def __call__(
suffix: str = "_output",
group_by: list[str] | None = None,
safe_mode: bool = False,
+        progress_bar_desc: str = "Aggregating",
) -> pd.DataFrame:
"""
Applies semantic aggregation over a dataframe.
@@ -178,7 +180,7 @@ def __call__(
grouped = self._obj.groupby(group_by)
new_df = pd.DataFrame()
for name, group in grouped:
-            res = group.sem_agg(user_instruction, all_cols, suffix, None)
+            res = group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)
new_df = pd.concat([new_df, res])
return new_df

@@ -200,6 +202,7 @@ def __call__(
formatted_usr_instr,
partition_ids,
safe_mode=safe_mode,
+            progress_bar_desc=progress_bar_desc,
)

# package answer in a dataframe
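With the parameter threaded through `sem_agg`, callers can label the bar per operation. A hedged usage sketch; the data and instruction are illustrative:

```python
import pandas as pd

import lotus
from lotus.models import LM

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

df = pd.DataFrame({"Course Name": ["Probability and Random Processes", "Optimization Methods in Engineering"]})
out = df.sem_agg("Summarize the common themes of {Course Name}", progress_bar_desc="Aggregating courses")
print(out)
```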
5 changes: 4 additions & 1 deletion lotus/sem_ops/sem_extract.py
@@ -18,6 +18,7 @@ def sem_extract(
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
safe_mode: bool = False,
+    progress_bar_desc: str = "Extracting",
) -> SemanticExtractOutput:
"""
Extracts attributes and values from a list of documents using a model.
@@ -48,7 +49,7 @@ def sem_extract(
show_safe_mode(estimated_cost, estimated_LM_calls)

# call model
-    lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})
+    lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc)

# post process results
postprocess_output = postprocessor(lm_output.outputs)
@@ -79,6 +80,7 @@ def __call__(
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
safe_mode: bool = False,
+        progress_bar_desc: str = "Extracting",
) -> pd.DataFrame:
"""
Extracts the attributes and values of a dataframe.
@@ -108,6 +110,7 @@ def __call__(
extract_quotes=extract_quotes,
postprocessor=postprocessor,
safe_mode=safe_mode,
+            progress_bar_desc=progress_bar_desc,
)

new_df = self._obj.copy()
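`sem_extract` gets the same treatment; note the JSON response format is requested internally. In this sketch the column arguments (`input_cols`, `output_cols`) are assumptions about the accessor's full signature, which the hunks above do not show:

```python
import pandas as pd

import lotus
from lotus.models import LM

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

df = pd.DataFrame({"quote": ["I am Charles Dickens, born in 1812."]})
out = df.sem_extract(
    input_cols=["quote"],  # assumed parameter name
    output_cols={"name": None, "birth_year": None},  # assumed parameter name and shape
    progress_bar_desc="Extracting bios",
)
print(out)
```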
(Diffs for the remaining changed files were not loaded on the page.)

