diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index bad0d5e3..4d91912b 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -7,6 +7,7 @@ import lotus
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
+from lotus.utils import show_safe_mode
 
 from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
 from .postprocessors import filter_postprocess
@@ -22,6 +23,7 @@ def sem_filter(
     cot_reasoning: list[str] | None = None,
     strategy: str | None = None,
     logprobs: bool = False,
+    safe_mode: bool = False,
 ) -> SemanticFilterOutput:
     """
     Filters a list of documents based on a given user instruction using a language model.
@@ -47,6 +49,12 @@ def sem_filter(
         lotus.logger.debug(f"input to model: {prompt}")
         inputs.append(prompt)
     kwargs: dict[str, Any] = {"logprobs": logprobs}
+
+    if safe_mode:
+        estimated_total_calls = len(docs)
+        estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
+        show_safe_mode(estimated_total_cost, estimated_total_calls)
+
     lm_output: LMOutput = model(inputs, **kwargs)
 
     postprocess_output = filter_postprocess(
@@ -56,6 +64,8 @@ def sem_filter(
     lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
     lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
 
+    model.print_total_usage()
+
     return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None)
@@ -88,6 +98,7 @@ def learn_filter_cascade_thresholds(
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
             strategy=strategy,
+            safe_mode=False,
         ).outputs
 
         best_combination, _ = learn_cascade_thresholds(
@@ -137,6 +148,7 @@ def __call__(
         precision_target: float | None = None,
         failure_probability: float | None = None,
         return_stats: bool = False,
+        safe_mode: bool = False,
     ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
         """
         Applies semantic filter over a dataframe.
@@ -221,6 +233,7 @@ def __call__(
                 cot_reasoning=helper_cot_reasoning,
                 logprobs=True,
                 strategy=helper_strategy,
+                safe_mode=safe_mode,
             )
             helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
             formatted_helper_logprobs: LogprobsForFilterCascade = (
@@ -302,6 +315,7 @@ def __call__(
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
 
             for idx, large_idx in enumerate(low_conf_idxs):
@@ -322,6 +336,7 @@ def __call__(
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
                 strategy=strategy,
+                safe_mode=safe_mode,
             )
             outputs = output.outputs
             raw_outputs = output.raw_outputs
diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py
index 62fcbf43..1d788fef 100644
--- a/lotus/sem_ops/sem_topk.py
+++ b/lotus/sem_ops/sem_topk.py
@@ -187,6 +187,15 @@ def llm_quicksort(
     stats = {}
     stats["total_tokens"] = 0
     stats["total_llm_calls"] = 0
+    sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy)
+    estimated_quickselect_calls = 2 * K
+    estimated_quicksort_calls = 2 * len(docs) * np.log(len(docs))
+    estimated_total_calls = estimated_quickselect_calls + estimated_quicksort_calls
+    estimated_total_tokens = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
+    if safe_mode:
+        print("Quicksort:")
+        show_safe_mode(estimated_total_tokens, estimated_total_calls)
+        print("\n")
 
     if cascade_threshold is not None:
         stats["total_small_tokens"] = 0
@@ -214,16 +223,10 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int:
         indexes[pivot_index], indexes[high] = indexes[high], indexes[pivot_index]
 
         pairs = [(docs[indexes[j]], pivot) for j in range(low, high)]
-        if safe_mode:
-            estimated_LM_calls = len(pairs)
         if cascade_threshold is None:
             comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
             stats["total_tokens"] += tokens
             stats["total_llm_calls"] += len(pairs)
-            if safe_mode:
-                estimated_costs = tokens
-                print("Quicksort:")
-                show_safe_mode(estimated_costs, estimated_LM_calls)
         else:
             comparisons, small_tokens, large_tokens, num_large_calls = compare_batch_binary_cascade(
                 pairs,
@@ -301,22 +304,25 @@ def llm_heapsort(
     Returns:
         SemanticTopKOutput: The indexes of the top k documents and stats.
     """
+    sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy)
+    estimated_heap_construction_calls = len(docs) * np.log(len(docs))
+    estimated_top_k_extraction_calls = K * np.log(len(docs))
+    estimated_total_calls = estimated_heap_construction_calls + estimated_top_k_extraction_calls
+    estimated_total_cost = lotus.settings.lm.count_tokens(sample_prompt) * estimated_total_calls
+    if safe_mode:
+        print("Heap Sort:")
+        show_safe_mode(estimated_total_cost, estimated_total_calls)
+        print("\n")
+
     HeapDoc.num_calls = 0
     HeapDoc.total_tokens = 0
     HeapDoc.strategy = strategy
     N = len(docs)
     heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)]
-    estimated_LM_calls = len(heap)
     heap = heapq.nsmallest(K, heap)
     indexes = [heapq.heappop(heap).idx for _ in range(len(heap))]
-    estimated_cost = HeapDoc.total_tokens
-    if safe_mode:
-        print("Heap Sort:")
-        show_safe_mode(estimated_cost, estimated_LM_calls)
-        print("\n")
-
     stats = {"total_tokens": HeapDoc.total_tokens, "total_llm_calls": HeapDoc.num_calls}
     return SemanticTopKOutput(indexes=indexes, stats=stats)
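
Usage sketch (not part of the patch): a minimal example of how the new safe_mode flag surfaces through the pandas accessor, assuming the LM wrapper and lotus.settings.configure API from the same codebase. The model name and dataframe contents are placeholders, not taken from the diff.

    import pandas as pd

    import lotus
    from lotus.models import LM

    # Placeholder model name; any model supported by the LM wrapper works here.
    lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

    df = pd.DataFrame({"Course Name": ["Linear Algebra", "Cooking 101", "Real Analysis"]})

    # With safe_mode=True, sem_filter first reports the estimated input-token
    # cost and LLM call count (via show_safe_mode) before issuing any requests.
    filtered = df.sem_filter("{Course Name} is a math course", safe_mode=True)
    print(filtered)

Note the difference in how the estimates are produced: sem_filter computes them from the already-rendered prompts, so the call count is exact (len(docs)) and the token figure covers inputs only, whereas llm_quicksort and llm_heapsort extrapolate both figures from a single sample comparison prompt scaled by the expected number of comparisons.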