From 9689736b54187b531f1ff01c8ea878e39326ba80 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Wed, 27 Nov 2024 17:32:37 -0500
Subject: [PATCH] safe_mode flag and removed reset_stats

---
 lotus/models/lm.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/lotus/models/lm.py b/lotus/models/lm.py
index 230507a8..cd07a6e8 100644
--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -25,6 +25,7 @@ def __init__(
         max_batch_size: int = 64,
         tokenizer: Tokenizer | None = None,
         max_cache_size: int = 1024,
+        safe_mode: bool = False,
         **kwargs: dict[str, Any],
     ):
         self.model = model
@@ -32,12 +33,15 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_batch_size = max_batch_size
         self.tokenizer = tokenizer
+        self.safe_mode = safe_mode
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
         self.stats: LMStats = LMStats()
         self.cache = Cache(max_cache_size)
 
-    def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
+    def __call__(
+        self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
+    ) -> LMOutput:
         all_kwargs = {**self.kwargs, **kwargs}
 
         # Set top_logprobs if logprobs requested
@@ -65,20 +69,17 @@ def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any
         logprobs = (
             [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
         )
-
-        self.print_total_usage()
-        self.reset_stats()
+        if self.safe_mode:
+            self.print_total_usage()
 
         return LMOutput(outputs=outputs, logprobs=logprobs)
 
     def _process_uncached_messages(self, uncached_data, all_kwargs):
         """Processes uncached messages in batches and returns responses."""
         uncached_responses = []
-        with tqdm(total=len(uncached_data), desc="Processing uncached messages") as pbar:
-            for i in range(0, len(uncached_data), self.max_batch_size):
-                batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
-                uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
-                pbar.update(len(batch))
+        for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
+            batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
+            uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
         return uncached_responses
 
     def _cache_response(self, response, hash):