Skip to content

Commit

Permalink
remove safe_mode from LM.py + print_total_usage
Browse files Browse the repository at this point in the history
  • Loading branch information
StanChan03 committed Dec 7, 2024
1 parent d301cd8 commit 4d79412
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 12 deletions.
2 changes: 0 additions & 2 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ def __init__(
max_batch_size: int = 64,
tokenizer: Tokenizer | None = None,
max_cache_size: int = 1024,
safe_mode: bool = False,
**kwargs: dict[str, Any],
):
self.model = model
self.max_ctx_len = max_ctx_len
self.max_tokens = max_tokens
self.max_batch_size = max_batch_size
self.tokenizer = tokenizer
self.safe_mode = safe_mode
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

self.stats: LMStats = LMStats()
Expand Down
4 changes: 2 additions & 2 deletions lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
docs = summaries
lotus.logger.debug(f"Model outputs from tree level {tree_level}: {summaries}")
tree_level += 1

model.print_total_usage()
if safe_mode:
model.print_total_usage()

return SemanticAggOutput(outputs=summaries)

Expand Down
4 changes: 2 additions & 2 deletions lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def sem_extract(
postprocess_output = postprocessor(lm_output.outputs)
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")

model.print_total_usage()
if safe_mode:
model.print_total_usage()

return SemanticExtractOutput(**postprocess_output.model_dump())

Expand Down
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def sem_filter(
lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")

model.print_total_usage()
if safe_mode:
model.print_total_usage()

return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None)

Expand Down
4 changes: 2 additions & 2 deletions lotus/sem_ops/sem_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def sem_map(
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")

model.print_total_usage()
if safe_mode:
model.print_total_usage()

return SemanticMapOutput(**postprocess_output.model_dump())

Expand Down
3 changes: 0 additions & 3 deletions lotus/sem_ops/sem_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def llm_naive_sort(
llm_calls = len(pairs)
comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
if safe_mode:
print("Naive Sort:")
show_safe_mode(tokens, llm_calls)
votes = [0] * N
idx = 0
Expand Down Expand Up @@ -192,7 +191,6 @@ def llm_quicksort(
estimated_quicksort_calls = 2 * len(docs) * np.log(len(docs))
estimated_total_calls = estimated_quickselect_calls + estimated_quicksort_calls
estimated_total_tokens = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
print("Quicksort:")
show_safe_mode(estimated_total_tokens, estimated_total_calls)

if cascade_threshold is not None:
Expand Down Expand Up @@ -309,7 +307,6 @@ def llm_heapsort(
estimated_top_k_extraction_calls = K * np.log(len(docs))
estimated_total_calls = estimated_heap_construction_calls + estimated_top_k_extraction_calls
estimated_total_cost = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
print("Heap Sort:")
show_safe_mode(estimated_total_cost, estimated_total_calls)

HeapDoc.num_calls = 0
Expand Down

0 comments on commit 4d79412

Please sign in to comment.