safe_mode for sem_agg+sem_topk
StanChan03 committed Dec 2, 2024
1 parent 786cfdc commit 16dbaa4
Showing 6 changed files with 67 additions and 7 deletions.
3 changes: 0 additions & 3 deletions lotus/models/lm.py
@@ -73,9 +73,6 @@ def __call__(
         logprobs = (
             [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
         )
-        if self.safe_mode:
-            print("\n")
-            self.print_total_usage()

         return LMOutput(outputs=outputs, logprobs=logprobs)
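With this change, LM.__call__ no longer prints a usage report after every batch; reporting moves up into the semantic operators, which take a per-call safe_mode flag instead. The sketch below is an assumed composite of the hunks that follow, showing the operator-level pattern; model, inputs, and docs stand in for each operator's own variables:

    # Operator-level safe-mode pattern introduced by this commit (sketch).
    if safe_mode:
        estimated_cost = sum(model.count_tokens(input) for input in inputs)  # token count as cost proxy
        estimated_LM_calls = len(docs)
        show_safe_mode(estimated_cost, estimated_LM_calls)  # print estimate, then 5-second countdown

    lm_output = model(inputs)   # the actual batched LM calls
    model.print_total_usage()   # real usage, printed once per operator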
22 changes: 22 additions & 0 deletions lotus/sem_ops/sem_agg.py
@@ -5,13 +5,15 @@
 import lotus.models
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticAggOutput
+from lotus.utils import show_safe_mode


 def sem_agg(
     docs: list[str],
     model: lotus.models.LM,
     user_instruction: str,
     partition_ids: list[int],
+    safe_mode: bool = False,
 ) -> SemanticAggOutput:
     """
     Aggregates multiple documents into a single answer using a model.
@@ -76,6 +78,12 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
         template_tokens = model.count_tokens(template)
         context_tokens = 0
         doc_ctr = 1  # num docs in current prompt
+
+        if safe_mode:
+            print(f"Starting tree level {tree_level} aggregation with {len(docs)} docs")
+            estimated_LM_calls = 0
+            estimated_costs = 0
+
         for idx in range(len(docs)):
             partition_id = partition_ids[idx]
             formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
@@ -98,6 +106,9 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
                 context_str = formatted_doc
                 context_tokens = new_tokens
                 doc_ctr += 1
+                if safe_mode:
+                    estimated_LM_calls += 1
+                    estimated_costs += model.count_tokens(prompt)
             else:
                 context_str = context_str + formatted_doc
                 context_tokens += new_tokens
@@ -108,6 +119,13 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
         lotus.logger.debug(f"Prompt added to batch: {prompt}")
         batch.append([{"role": "user", "content": prompt}])
         new_partition_ids.append(cur_partition_id)
+        if safe_mode:
+            estimated_LM_calls += 1
+            estimated_costs += model.count_tokens(prompt)
+
+        if safe_mode:
+            show_safe_mode(estimated_costs, estimated_LM_calls)
+
         lm_output: LMOutput = model(batch)

         summaries = lm_output.outputs
@@ -118,6 +136,8 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
         lotus.logger.debug(f"Model outputs from tree level {tree_level}: {summaries}")
         tree_level += 1

+    model.print_total_usage()
+
     return SemanticAggOutput(outputs=summaries)


@@ -139,6 +159,7 @@ def __call__(
         all_cols: bool = False,
         suffix: str = "_output",
         group_by: list[str] | None = None,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Applies semantic aggregation over a dataframe.
@@ -189,6 +210,7 @@ def __call__(
             lotus.settings.lm,
             formatted_usr_instr,
             partition_ids,
+            safe_mode=safe_mode,
         )

         # package answer in a dataframe
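With safe_mode now threaded from the accessor down into sem_agg, callers can opt in per aggregation. A minimal usage sketch follows; lotus.settings.configure and the {column} instruction syntax are assumed from the wider lotus API, and the model name and column are hypothetical. Since the estimate is computed inside the tree-reduction loop, expect one estimate and countdown per tree level:

    import pandas as pd
    import lotus
    from lotus.models import LM

    lotus.settings.configure(lm=LM(model="gpt-4o-mini"))  # assumed configuration API

    df = pd.DataFrame({"abstract": ["Paper A studies ...", "Paper B proposes ..."]})

    # Prints the estimated token cost and LM-call count, then a 5-second
    # countdown (CTRL+C cancels) before the level's LM calls are issued.
    out = df.sem_agg("Summarize the main findings in {abstract}", safe_mode=True)
    print(out["_output"][0])  # "_output" is the default suffix per the diff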
7 changes: 6 additions & 1 deletion lotus/sem_ops/sem_extract.py
@@ -17,6 +17,7 @@ def sem_extract(
     output_cols: dict[str, str | None],
     extract_quotes: bool = False,
     postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
+    safe_mode: bool = False,
 ) -> SemanticExtractOutput:
     """
     Extracts attributes and values from a list of documents using a model.
@@ -41,7 +42,7 @@ def sem_extract(
         inputs.append(prompt)

     # check if safe_mode is enabled
-    if model.safe_mode:
+    if safe_mode:
         estimated_cost = sum(model.count_tokens(input) for input in inputs)
         estimated_LM_calls = len(docs)
         show_safe_mode(estimated_cost, estimated_LM_calls)
@@ -54,6 +55,8 @@
     lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
     lotus.logger.debug(f"outputs: {postprocess_output.outputs}")

+    model.print_total_usage()
+
     return SemanticExtractOutput(**postprocess_output.model_dump())


@@ -75,6 +78,7 @@ def __call__(
         extract_quotes: bool = False,
         postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
         return_raw_outputs: bool = False,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Extracts the attributes and values of a dataframe.
@@ -103,6 +107,7 @@ def __call__(
             output_cols=output_cols,
             extract_quotes=extract_quotes,
             postprocessor=postprocessor,
+            safe_mode=safe_mode,
         )

         new_df = self._obj.copy()
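sem_extract follows the same pattern but estimates once up front: len(docs) LM calls and the summed token count of every prompt. A usage sketch, assuming the accessor's input_cols parameter from the broader lotus API (the output_cols dict shape matches the signature in the diff; column names are hypothetical):

    df = pd.DataFrame({"review": ["Great battery, dim screen.", "Fast, but the fan is loud."]})

    # With safe_mode=True, the estimate and countdown run before the batch executes.
    extracted = df.sem_extract(
        input_cols=["review"],
        output_cols={"sentiment": "overall sentiment of the review"},
        safe_mode=True,
    )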
7 changes: 6 additions & 1 deletion lotus/sem_ops/sem_map.py
@@ -19,6 +19,7 @@ def sem_map(
     examples_answers: list[str] | None = None,
     cot_reasoning: list[str] | None = None,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticMapOutput:
     """
     Maps a list of documents to a list of outputs using a model.
@@ -46,7 +47,7 @@ def sem_map(
         inputs.append(prompt)

     # check if safe_mode is enabled
-    if model.safe_mode:
+    if safe_mode:
         estimated_cost = sum(model.count_tokens(input) for input in inputs)
         estimated_LM_calls = len(docs)
         show_safe_mode(estimated_cost, estimated_LM_calls)
@@ -60,6 +61,8 @@
     lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
     lotus.logger.debug(f"explanations: {postprocess_output.explanations}")

+    model.print_total_usage()
+
     return SemanticMapOutput(**postprocess_output.model_dump())


@@ -85,6 +88,7 @@ def __call__(
         suffix: str = "_map",
         examples: pd.DataFrame | None = None,
         strategy: str | None = None,
+        safe_mode: bool = False,
     ) -> pd.DataFrame:
         """
         Applies semantic map over a dataframe.
@@ -132,6 +136,7 @@ def __call__(
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
             strategy=strategy,
+            safe_mode=safe_mode,
         )

         new_df = self._obj.copy()
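sem_map's estimate is likewise a single batch: one LM call per document plus the total prompt token count. A usage sketch (instruction syntax as above; the column name is hypothetical):

    df = pd.DataFrame({"course": ["Operating Systems", "Baroque Music History"]})

    # One estimate and countdown, then one batched model call; results land
    # in the "_map" suffix column by default, per the signature in the diff.
    mapped = df.sem_map("What department most likely offers {course}?", safe_mode=True)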
33 changes: 32 additions & 1 deletion lotus/sem_ops/sem_topk.py
@@ -8,6 +8,7 @@
 import lotus
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticTopKOutput
+from lotus.utils import show_safe_mode


 def get_match_prompt_binary(
@@ -121,6 +122,7 @@ def llm_naive_sort(
     docs: list[dict[str, Any]],
     user_instruction: str,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
     Sorts the documents using a naive quadratic method.
@@ -140,6 +142,10 @@ def llm_naive_sort(

     llm_calls = len(pairs)
     comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
+    if safe_mode:
+        print("Naive Sort:")
+        show_safe_mode(tokens, llm_calls)
+        print("\n")
     votes = [0] * N
     idx = 0
     for i in range(N):
@@ -163,6 +169,7 @@ def llm_quicksort(
     embedding: bool = False,
     strategy: str | None = None,
     cascade_threshold: float | None = None,
+    safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
     Sorts the documents using quicksort.
@@ -207,10 +214,16 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int:
         indexes[pivot_index], indexes[high] = indexes[high], indexes[pivot_index]

         pairs = [(docs[indexes[j]], pivot) for j in range(low, high)]
+        if safe_mode:
+            estimated_LM_calls = len(pairs)
         if cascade_threshold is None:
             comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
             stats["total_tokens"] += tokens
             stats["total_llm_calls"] += len(pairs)
+            if safe_mode:
+                estimated_costs = tokens
+                print("Quicksort:")
+                show_safe_mode(estimated_costs, estimated_LM_calls)
         else:
             comparisons, small_tokens, large_tokens, num_large_calls = compare_batch_binary_cascade(
                 pairs,
@@ -275,6 +288,7 @@ def llm_heapsort(
     user_instruction: str,
     K: int,
     strategy: str | None = None,
+    safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
     Sorts the documents using a heap.
@@ -292,9 +306,17 @@ def llm_heapsort(
     HeapDoc.strategy = strategy
     N = len(docs)
     heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)]
+    estimated_LM_calls = len(heap)
+
     heap = heapq.nsmallest(K, heap)
     indexes = [heapq.heappop(heap).idx for _ in range(len(heap))]
+
+    estimated_cost = HeapDoc.total_tokens
+    if safe_mode:
+        print("Heap Sort:")
+        show_safe_mode(estimated_cost, estimated_LM_calls)
+        print("\n")
+
     stats = {"total_tokens": HeapDoc.total_tokens, "total_llm_calls": HeapDoc.num_calls}
     return SemanticTopKOutput(indexes=indexes, stats=stats)

@@ -320,6 +342,7 @@ def __call__(
         group_by: list[str] | None = None,
         cascade_threshold: float | None = None,
         return_stats: bool = False,
+        safe_mode: bool = False,
     ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
         """
         Sorts the DataFrame based on the user instruction and returns the top K rows.
@@ -392,14 +415,22 @@ def __call__(
                 embedding=method == "quick-sem",
                 strategy=strategy,
                 cascade_threshold=cascade_threshold,
+                safe_mode=safe_mode,
             )
         elif method == "heap":
-            output = llm_heapsort(multimodal_data, formatted_usr_instr, K, strategy=strategy)
+            output = llm_heapsort(
+                multimodal_data,
+                formatted_usr_instr,
+                K,
+                strategy=strategy,
+                safe_mode=safe_mode,
+            )
         elif method == "naive":
             output = llm_naive_sort(
                 multimodal_data,
                 formatted_usr_instr,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
         else:
             raise ValueError(f"Method {method} not recognized")
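Each sorting method surfaces its estimate at a different point: naive sort once after materializing all pairwise comparisons, quicksort inside each partition step, and heap sort from HeapDoc.total_tokens, which as written is read after heapq.nsmallest has already triggered the comparisons. A usage sketch (K and method match names in the diff; other method strings and the instruction are assumptions):

    top3 = df.sem_topk(
        "Which {abstract} is most relevant to deep learning on tabular data?",
        K=3,
        method="heap",  # "heap" and "naive" appear in the diff's dispatch
        safe_mode=True,
    )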
2 changes: 1 addition & 1 deletion lotus/utils.py
@@ -116,7 +116,7 @@ def show_safe_mode(estimated_cost, estimated_LM_calls):
         for i in range(5, 0, -1):
             print(f"Proceeding execution in {i} seconds... Press CTRL+C to cancel", end="\r")
             time.sleep(1)
-        print(" " * 30, end="\r")
+        print(" " * 60, end="\r")
     except KeyboardInterrupt:
         print("\nExecution cancelled by user")
         exit(0)
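For context, the helper this hunk adjusts is reconstructed below. Only the try/except body shown in the diff is confirmed; the signature comes from the hunk header, and the estimate-printing lines are assumptions. The countdown message is about 59 characters long, so the old 30-space carriage-return overwrite left half of it on screen; 60 spaces clears it fully:

    import time

    def show_safe_mode(estimated_cost, estimated_LM_calls):
        # Assumed: report the estimates before counting down.
        print(f"Estimated cost: {estimated_cost} tokens")
        print(f"Estimated LM calls: {estimated_LM_calls}")
        try:
            for i in range(5, 0, -1):
                print(f"Proceeding execution in {i} seconds... Press CTRL+C to cancel", end="\r")
                time.sleep(1)
            print(" " * 60, end="\r")  # widened from 30 to clear the whole line
        except KeyboardInterrupt:
            print("\nExecution cancelled by user")
            exit(0)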
