
Commit

Small updates
sidjha1 committed Dec 10, 2024
1 parent c363089 commit 7c24b9e
Showing 8 changed files with 33 additions and 27 deletions.
Binary file removed examples/op_examples/Skill:right_index/index
Binary file removed examples/op_examples/Skill:right_index/vecs
lotus/models/cross_encoder_reranker.py (2 changes: 1 addition & 1 deletion)
@@ -23,6 +23,6 @@ def __init__(
         self.model = CrossEncoder(model, device=device)  # type: ignore # CrossEncoder has wrong type stubs

     def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput:
-        results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size)
+        results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size, show_progress_bar=False)
         indices = [int(result["corpus_id"]) for result in results]
         return RerankerOutput(indices=indices)
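
Note: the only functional change here is muting sentence-transformers' own progress bar. A minimal sketch of the call this hunk touches, with an illustrative model name and toy documents (not lotus code):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # illustrative checkpoint
    results = model.rank(
        "what is a transformer?",
        ["Transformers are a neural architecture.", "A transformer converts voltages."],
        top_k=2,
        batch_size=64,
        show_progress_bar=False,  # suppress the per-call tqdm bar
    )
    indices = [int(r["corpus_id"]) for r in results]  # positions into the input list, best first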
lotus/models/lm.py (8 changes: 4 additions & 4 deletions)
@@ -44,7 +44,7 @@ def __init__(
     def __call__(
         self,
         messages: list[list[dict[str, str]]],
-        show_pbar: bool = True,
+        show_progress_bar: bool = True,
         **kwargs: dict[str, Any],
     ) -> LMOutput:
         all_kwargs = {**self.kwargs, **kwargs}
@@ -62,7 +62,7 @@ def __call__(
         self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

         # Process uncached messages in batches
-        uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs, show_pbar)
+        uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs, show_progress_bar)

         # Add new responses to cache
         for resp, (_, hash) in zip(uncached_responses, uncached_data):
@@ -77,15 +77,15 @@ def __call__(

         return LMOutput(outputs=outputs, logprobs=logprobs)

-    def _process_uncached_messages(self, uncached_data, all_kwargs, show_pbar):
+    def _process_uncached_messages(self, uncached_data, all_kwargs, show_progress_bar):
         """Processes uncached messages in batches and returns responses."""
         uncached_responses = []
         total_calls = len(uncached_data)

         pbar = tqdm(
             total=total_calls,
             desc="Processing uncached messages",
-            disable=not show_pbar,
+            disable=not show_progress_bar,
             bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
         )
         for i in range(0, total_calls, self.max_batch_size):
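
The rename from show_pbar to show_progress_bar is mechanical; the pattern it feeds is tqdm's disable flag, which is the inverse of the caller-facing boolean. A toy sketch of that pattern (the function and its body are placeholders, not the lotus LM class):

    from tqdm import tqdm

    def process_batches(items: list[str], show_progress_bar: bool = True) -> list[str]:
        # Mirrors `disable=not show_progress_bar` from _process_uncached_messages.
        pbar = tqdm(total=len(items), desc="Processing uncached messages", disable=not show_progress_bar)
        outputs = []
        for item in items:
            outputs.append(item.upper())  # stand-in for one batched LM call
            pbar.update(1)
        pbar.close()
        return outputs

    quiet = process_batches(["a", "b"], show_progress_bar=False)  # no bar is drawn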
lotus/models/sentence_transformers_rm.py (2 changes: 1 addition & 1 deletion)
@@ -31,7 +31,7 @@ def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
             batch = docs[i : i + self.max_batch_size]
             _batch = convert_to_base_data(batch)
             torch_embeddings = self.transformer.encode(
-                _batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings
+                _batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings, show_progress_bar=False
             )
             assert isinstance(torch_embeddings, torch.Tensor)
             cpu_embeddings = torch_embeddings.cpu().numpy()
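
SentenceTransformer.encode can draw its own "Batches" tqdm bar, so the embedding path is silenced the same way as the reranker. A minimal sketch with an illustrative model name (not the lotus wrapper):

    import numpy as np
    from sentence_transformers import SentenceTransformer

    transformer = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative checkpoint
    embeddings = transformer.encode(
        ["first document", "second document"],
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=False,  # keep encode() from printing its own bar
    )
    cpu_embeddings = embeddings.cpu().numpy()
    assert isinstance(cpu_embeddings, np.ndarray)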
lotus/sem_ops/sem_filter.py (8 changes: 4 additions & 4 deletions)
@@ -24,7 +24,7 @@ def sem_filter(
     strategy: str | None = None,
     logprobs: bool = False,
     safe_mode: bool = False,
-    show_pbar: bool = True,
+    show_progress_bar: bool = True,
 ) -> SemanticFilterOutput:
     """
     Filters a list of documents based on a given user instruction using a language model.
@@ -56,7 +56,7 @@
         estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
         show_safe_mode(estimated_total_cost, estimated_total_calls)

-    lm_output: LMOutput = model(inputs, show_pbar=show_pbar, **kwargs)
+    lm_output: LMOutput = model(inputs, show_progress_bar=show_progress_bar, **kwargs)

     postprocess_output = filter_postprocess(
         lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]
@@ -236,7 +236,7 @@ def __call__(
             logprobs=True,
             strategy=helper_strategy,
             safe_mode=safe_mode,
-            show_pbar=True,
+            show_progress_bar=True,
         )
         helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
         formatted_helper_logprobs: LogprobsForFilterCascade = (
@@ -340,7 +340,7 @@ def __call__(
             cot_reasoning=cot_reasoning,
             strategy=strategy,
             safe_mode=safe_mode,
-            show_pbar=True,
+            show_progress_bar=True,
         )
         outputs = output.outputs
         raw_outputs = output.raw_outputs
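
Within sem_filter the flag is only forwarded: the operator accepts show_progress_bar and hands it straight to the model call, while the top-level entry points pin it to True. A toy sketch of that forwarding chain (fake_model and the prompt format are placeholders, not lotus signatures):

    from typing import Any

    def fake_model(inputs: list[str], show_progress_bar: bool = True, **kwargs: Any) -> list[str]:
        if show_progress_bar:
            print(f"processing {len(inputs)} LM calls...")  # stand-in for a tqdm bar
        return ["True" for _ in inputs]

    def sem_filter(docs: list[str], model, instruction: str, show_progress_bar: bool = True) -> list[bool]:
        inputs = [f"{instruction}: {doc}" for doc in docs]
        outputs = model(inputs, show_progress_bar=show_progress_bar)  # forwarded unchanged
        return [out == "True" for out in outputs]

    kept = sem_filter(["doc a", "doc b"], fake_model, "is relevant", show_progress_bar=False)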
lotus/sem_ops/sem_join.py (15 changes: 7 additions & 8 deletions)
@@ -27,7 +27,7 @@ def sem_join(
     default: bool = True,
     strategy: str | None = None,
     safe_mode: bool = False,
-    show_pbar: bool = True,
+    show_progress_bar: bool = True,
 ) -> SemanticJoinOutput:
     """
     Joins two series using a model.
@@ -69,7 +69,7 @@
         estimated_total_cost = estimated_tokens_per_call * estimated_total_calls
         print("Sem_Join:")
         show_safe_mode(estimated_total_cost, estimated_total_calls)
-    if show_pbar:
+    if show_progress_bar:
         pbar = tqdm(
             total=len(l1) * len(l2),
             desc="Processing uncached messages",
@@ -88,7 +88,7 @@
             cot_reasoning=cot_reasoning,
             default=default,
             strategy=strategy,
-            show_pbar=False,
+            show_progress_bar=False,
         )
         outputs = output.outputs
         raw_outputs = output.raw_outputs
@@ -105,7 +105,7 @@
                 if output
             ]
         )
-    if show_pbar:
+    if show_progress_bar:
         pbar.update(len(l1) * len(l2))
         pbar.close()

@@ -217,7 +217,7 @@ def sem_join_cascade(
     join_results = [(row["_left_id"], row["_right_id"], None) for _, row in helper_high_conf.iterrows()]

     pbar = tqdm(
-        total=len(helper_low_conf) + len(helper_high_conf),
+        total=num_large,
         desc="Processing uncached messages",
         bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
     )
@@ -240,9 +240,9 @@
             cot_reasoning=cot_reasoning,
             default=default,
             strategy=strategy,
-            show_pbar=False,
+            show_progress_bar=False,
         )
-    pbar.update(len(helper_low_conf) + len(helper_high_conf))
+    pbar.update(num_large)
     pbar.close()
     join_results.extend(large_join_output.join_results)

@@ -530,7 +530,6 @@ def learn_join_cascade_threshold(
         examples_answers=examples_answers,
         cot_reasoning=cot_reasoning,
         strategy=strategy,
-        show_pbar=False,
     )

     (pos_threshold, neg_threshold), _ = learn_cascade_thresholds(
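
Beyond the rename, sem_join_cascade gets a correctness fix: the bar's total previously counted both confidence buckets, but high-confidence pairs are accepted from the helper directly, so the bar now uses num_large, the count of pairs actually re-resolved by the large model, advanced in one bulk update while the inner join runs with show_progress_bar=False. A toy sketch of that outer-bar pattern (run_large_model stands in for the inner sem_join call):

    from tqdm import tqdm

    def run_large_model(pairs: list[tuple[str, str]]) -> list[bool]:
        return [left == right for left, right in pairs]  # stand-in for LM calls

    low_conf_pairs = [("a", "a"), ("a", "b"), ("b", "b")]
    num_large = len(low_conf_pairs)  # only the pairs sent to the large model
    pbar = tqdm(total=num_large, desc="Processing uncached messages")
    results = run_large_model(low_conf_pairs)  # inner call draws no bar of its own
    pbar.update(num_large)  # one bulk update once the batch returns
    pbar.close()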
lotus/sem_ops/sem_topk.py (25 changes: 16 additions & 9 deletions)
@@ -67,7 +67,7 @@ def compare_batch_binary(
     for doc1, doc2 in pairs:
         match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy))
         tokens += lotus.settings.lm.count_tokens(match_prompts[-1])
-    lm_results: LMOutput = lotus.settings.lm(match_prompts, show_pbar=False)
+    lm_results: LMOutput = lotus.settings.lm(match_prompts, show_progress_bar=False)
     results: list[bool] = list(map(parse_ans_binary, lm_results.outputs))
     return results, tokens
@@ -257,14 +257,21 @@ def quicksort_recursive(indexes: list[int], low: int, high: int, K: int) -> None:
         if high <= low:
             return

-        if low < high:
-            pi = partition(indexes, low, high, K)
-            left_size = pi - low
-            if left_size + 1 >= K:
-                quicksort_recursive(indexes, low, pi - 1, K)
-            else:
-                quicksort_recursive(indexes, low, pi - 1, left_size)
-                quicksort_recursive(indexes, pi + 1, high, K - left_size - 1)
+        num_comparisons = high - low
+        pbar = tqdm(
+            total=num_comparisons,
+            desc="Processing uncached messages",
+            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} LM calls [{elapsed}<{remaining}]",
+        )
+        pi = partition(indexes, low, high, K)
+        pbar.update(num_comparisons)
+        pbar.close()
+        left_size = pi - low
+        if left_size + 1 >= K:
+            quicksort_recursive(indexes, low, pi - 1, K)
+        else:
+            quicksort_recursive(indexes, low, pi - 1, left_size)
+            quicksort_recursive(indexes, pi + 1, high, K - left_size - 1)

     indexes = list(range(len(docs)))
     quicksort_recursive(indexes, 0, len(indexes) - 1, K)
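
Two things happen in this hunk: the redundant `if low < high` test is dropped (the `if high <= low: return` guard above already excludes that case), and each partition pass gets a bar sized to its comparison count, since a quicksort partition compares the pivot against the other high - low elements in the range. A self-contained toy that mirrors the new structure, using plain integer comparisons in place of LM calls:

    from tqdm import tqdm

    def top_k(vals: list[int], K: int) -> list[int]:
        """Partial quicksort for the K largest values; mirrors the diff's shape."""

        def partition(idx: list[int], low: int, high: int) -> int:
            pivot = vals[idx[high]]
            i = low - 1
            for j in range(low, high):  # exactly high - low comparisons
                if vals[idx[j]] > pivot:  # descending: larger values first
                    i += 1
                    idx[i], idx[j] = idx[j], idx[i]
            idx[i + 1], idx[high] = idx[high], idx[i + 1]
            return i + 1

        def quicksort_recursive(idx: list[int], low: int, high: int, K: int) -> None:
            if high <= low:
                return
            num_comparisons = high - low
            pbar = tqdm(total=num_comparisons, desc="Comparisons")
            pi = partition(idx, low, high)
            pbar.update(num_comparisons)
            pbar.close()
            left_size = pi - low
            if left_size + 1 >= K:
                quicksort_recursive(idx, low, pi - 1, K)
            else:
                quicksort_recursive(idx, low, pi - 1, left_size)
                quicksort_recursive(idx, pi + 1, high, K - left_size - 1)

        idx = list(range(len(vals)))
        quicksort_recursive(idx, 0, len(idx) - 1, K)
        return [vals[i] for i in idx[:K]]

    print(top_k([3, 1, 4, 1, 5, 9, 2, 6], 3))  # [9, 6, 5]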
