Skip to content

Commit

Permalink
sem_join + sem_filter pbars
Browse files Browse the repository at this point in the history
  • Loading branch information
StanChan03 committed Dec 8, 2024
1 parent 99f96b2 commit 255b39e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
23 changes: 19 additions & 4 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def __init__(
self.cache = Cache(max_cache_size)

def __call__(
self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
self,
messages: list[list[dict[str, str]]],
show_pbar: bool = True,
**kwargs: dict[str, Any],
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

Expand All @@ -59,7 +62,7 @@ def __call__(
self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs)
uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs, show_pbar)

# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
Expand All @@ -74,12 +77,24 @@ def __call__(

return LMOutput(outputs=outputs, logprobs=logprobs)

def _process_uncached_messages(self, uncached_data, all_kwargs):
def _process_uncached_messages(self, uncached_data, all_kwargs, show_pbar):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
total_calls = len(uncached_data)

pbar = tqdm(
total=total_calls,
desc="Processing uncached messages",
disable=not show_pbar,
bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
)
for i in range(0, total_calls, self.max_batch_size):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))

pbar.update(len(batch))
pbar.close()

return uncached_responses

def _cache_response(self, response, hash):
Expand Down
5 changes: 4 additions & 1 deletion lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def sem_filter(
strategy: str | None = None,
logprobs: bool = False,
safe_mode: bool = False,
show_pbar: bool = True,
) -> SemanticFilterOutput:
"""
Filters a list of documents based on a given user instruction using a language model.
Expand Down Expand Up @@ -55,7 +56,7 @@ def sem_filter(
estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
show_safe_mode(estimated_total_cost, estimated_total_calls)

lm_output: LMOutput = model(inputs, **kwargs)
lm_output: LMOutput = model(inputs, show_pbar=show_pbar, **kwargs)

postprocess_output = filter_postprocess(
lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]
Expand Down Expand Up @@ -235,6 +236,7 @@ def __call__(
logprobs=True,
strategy=helper_strategy,
safe_mode=safe_mode,
show_pbar=True,
)
helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
formatted_helper_logprobs: LogprobsForFilterCascade = (
Expand Down Expand Up @@ -338,6 +340,7 @@ def __call__(
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=safe_mode,
show_pbar=True,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
Expand Down
10 changes: 10 additions & 0 deletions lotus/sem_ops/sem_join.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

import pandas as pd
from tqdm import tqdm

import lotus
from lotus.templates import task_instructions
Expand Down Expand Up @@ -68,6 +69,11 @@ def sem_join(
print("Sem_Join:")
show_safe_mode(estimated_total_cost, estimated_total_calls)

pbar = tqdm(
total=len(l1),
desc="Processing uncached messages",
bar_format="{l_bar}{bar} {n}/{total} LM Calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
)
# for i1 in enumerate(l1):
for id1, i1 in zip(ids1, left_multimodal_data):
# perform llm filter
Expand All @@ -81,6 +87,7 @@ def sem_join(
cot_reasoning=cot_reasoning,
default=default,
strategy=strategy,
show_pbar=False,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
Expand All @@ -98,6 +105,9 @@ def sem_join(
]
)

pbar.update(1)
pbar.close()

lotus.logger.debug(f"outputs: {filter_outputs}")
lotus.logger.debug(f"explanations: {all_explanations}")

Expand Down

0 comments on commit 255b39e

Please sign in to comment.