sem_topk + sem_filter safe_mode
StanChan03 committed Dec 4, 2024
1 parent 16dbaa4 commit c693adf
Showing 2 changed files with 34 additions and 13 deletions.
15 changes: 15 additions & 0 deletions lotus/sem_ops/sem_filter.py
@@ -7,6 +7,7 @@
import lotus
from lotus.templates import task_instructions
from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
from lotus.utils import show_safe_mode

from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
from .postprocessors import filter_postprocess
@@ -22,6 +23,7 @@ def sem_filter(
cot_reasoning: list[str] | None = None,
strategy: str | None = None,
logprobs: bool = False,
safe_mode: bool = False,
) -> SemanticFilterOutput:
"""
Filters a list of documents based on a given user instruction using a language model.
@@ -47,6 +49,12 @@
lotus.logger.debug(f"input to model: {prompt}")
inputs.append(prompt)
kwargs: dict[str, Any] = {"logprobs": logprobs}

if safe_mode:
estimated_total_calls = len(docs)
estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
show_safe_mode(estimated_total_cost, estimated_total_calls)

lm_output: LMOutput = model(inputs, **kwargs)

postprocess_output = filter_postprocess(
@@ -56,6 +64,8 @@
lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")

model.print_total_usage()

return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None)


@@ -88,6 +98,7 @@ def learn_filter_cascade_thresholds(
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=False,
).outputs

best_combination, _ = learn_cascade_thresholds(
@@ -137,6 +148,7 @@ def __call__(
precision_target: float | None = None,
failure_probability: float | None = None,
return_stats: bool = False,
safe_mode: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
"""
Applies semantic filter over a dataframe.
@@ -221,6 +233,7 @@ def __call__(
cot_reasoning=helper_cot_reasoning,
logprobs=True,
strategy=helper_strategy,
safe_mode=safe_mode,
)
helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
formatted_helper_logprobs: LogprobsForFilterCascade = (
@@ -302,6 +315,7 @@ def __call__(
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=safe_mode,
)

for idx, large_idx in enumerate(low_conf_idxs):
@@ -322,6 +336,7 @@
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
safe_mode=safe_mode,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
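Both changed files report their estimates through show_safe_mode, imported from lotus.utils. The helper's implementation is not part of this diff; the sketch below is a hypothetical stand-in inferred from the call sites show_safe_mode(estimated_cost, estimated_calls), not the actual lotus code:

    # Hypothetical sketch of lotus.utils.show_safe_mode -- only the call site
    # appears in this diff, so the body below is an assumption.
    def show_safe_mode(estimated_cost: int, estimated_calls: int) -> None:
        # Surface the pre-flight estimates before any LM calls are issued.
        print(f"Estimated cost: {estimated_cost} tokens")
        print(f"Estimated LM calls: {estimated_calls}")

In sem_filter, safe_mode=True prints these estimates before the batched model call, and model.print_total_usage() reports the actual usage once postprocessing completes.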
32 changes: 19 additions & 13 deletions lotus/sem_ops/sem_topk.py
@@ -187,6 +187,15 @@ def llm_quicksort(
stats = {}
stats["total_tokens"] = 0
stats["total_llm_calls"] = 0
sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy)
estimated_quickselect_calls = 2 * K
estimated_quicksort_calls = 2 * len(docs) * np.log(len(docs))
estimated_total_calls = estimated_quickselect_calls + estimated_quicksort_calls
estimated_total_tokens = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
if safe_mode:
print("Quicksort:")
show_safe_mode(estimated_total_tokens, estimated_total_calls)
print("\n")

if cascade_threshold is not None:
stats["total_small_tokens"] = 0
Expand Down Expand Up @@ -214,16 +223,10 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int:
indexes[pivot_index], indexes[high] = indexes[high], indexes[pivot_index]

pairs = [(docs[indexes[j]], pivot) for j in range(low, high)]
if safe_mode:
estimated_LM_calls = len(pairs)
if cascade_threshold is None:
comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
stats["total_tokens"] += tokens
stats["total_llm_calls"] += len(pairs)
if safe_mode:
estimated_costs = tokens
print("Quicksort:")
show_safe_mode(estimated_costs, estimated_LM_calls)
else:
comparisons, small_tokens, large_tokens, num_large_calls = compare_batch_binary_cascade(
pairs,
Expand Down Expand Up @@ -301,22 +304,25 @@ def llm_heapsort(
Returns:
SemanticTopKOutput: The indexes of the top k documents and stats.
"""
sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy)
estimated_heap_construction_calls = len(docs) * np.log(len(docs))
estimated_top_k_extraction_calls = K * np.log(len(docs))
estimated_total_calls = estimated_heap_construction_calls + estimated_top_k_extraction_calls
estimated_total_cost = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
if safe_mode:
print("Heap Sort:")
show_safe_mode(estimated_total_cost, estimated_total_calls)
print("\n")

HeapDoc.num_calls = 0
HeapDoc.total_tokens = 0
HeapDoc.strategy = strategy
N = len(docs)
heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)]
estimated_LM_calls = len(heap)

heap = heapq.nsmallest(K, heap)
indexes = [heapq.heappop(heap).idx for _ in range(len(heap))]

estimated_cost = HeapDoc.total_tokens
if safe_mode:
print("Heap Sort:")
show_safe_mode(estimated_cost, estimated_LM_calls)
print("\n")

stats = {"total_tokens": HeapDoc.total_tokens, "total_llm_calls": HeapDoc.num_calls}
return SemanticTopKOutput(indexes=indexes, stats=stats)

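The estimates mirror the expected comparison counts of each algorithm: quicksort over n documents budgets 2K quickselect calls plus 2·n·log(n) sorting calls, while heapsort budgets n·log(n) calls for heap construction plus K·log(n) for extracting the top K (np.log is the natural logarithm). A self-contained sketch of that arithmetic, with an assumed per-prompt token count standing in for lotus.settings.lm.count_tokens(sample_prompt):

    import numpy as np

    n, K = 100, 10         # illustrative corpus size and top-k
    tokens_per_call = 500  # assumed stand-in for count_tokens(sample_prompt)

    # Quicksort path: 2*K quickselect calls + 2*n*ln(n) quicksort calls.
    quicksort_calls = 2 * K + 2 * n * np.log(n)
    # Heapsort path: n*ln(n) heap construction + K*ln(n) top-k extraction.
    heapsort_calls = n * np.log(n) + K * np.log(n)

    print(f"quicksort: ~{quicksort_calls:.0f} calls, ~{quicksort_calls * tokens_per_call:.0f} tokens")
    print(f"heapsort:  ~{heapsort_calls:.0f} calls, ~{heapsort_calls * tokens_per_call:.0f} tokens")

With these inputs the heapsort estimate (~507 calls) is roughly half the quicksort one (~941 calls), which is exactly the kind of comparison safe_mode lets a user see before any tokens are spent.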
