diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index 35a0a5c9..2d661902 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -3,21 +3,22 @@
 
 import lotus
 from lotus.models import LM
+from tokenizers import Tokenizer
 
 # Set logger level to DEBUG
 lotus.logger.setLevel("DEBUG")
 
 
 @pytest.fixture
-def setup_models():
+def setup_gpt_models():
     # Setup GPT models
     gpt_4o_mini = LM(model="gpt-4o-mini")
     gpt_4o = LM(model="gpt-4o")
     return gpt_4o_mini, gpt_4o
 
 
-def test_filter_operation(setup_models):
-    gpt_4o_mini, _ = setup_models
+def test_filter_operation(setup_gpt_models):
+    gpt_4o_mini, _ = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o_mini)
 
     # Test filter operation on an easy dataframe
@@ -30,8 +31,8 @@ def test_filter_operation(setup_models):
     assert filtered_df.equals(expected_df)
 
 
-def test_filter_cascade(setup_models):
-    gpt_4o_mini, gpt_4o = setup_models
+def test_filter_cascade(setup_gpt_models):
+    gpt_4o_mini, gpt_4o = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)
 
     data = {
@@ -99,8 +100,8 @@ def test_filter_cascade(setup_models):
     assert stats["filters_resolved_by_helper_model"] > 0, stats
 
 
-def test_top_k(setup_models):
-    gpt_4o_mini, _ = setup_models
+def test_top_k(setup_gpt_models):
+    gpt_4o_mini, _ = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o_mini)
 
     data = {
@@ -120,8 +121,8 @@ def test_top_k(setup_models):
     assert top_2_expected == top_2_actual
 
 
-def test_join(setup_models):
-    gpt_4o_mini, _ = setup_models
+def test_join(setup_gpt_models):
+    gpt_4o_mini, _ = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o_mini)
 
     data1 = {"School": ["UC Berkeley", "Stanford"]}
@@ -136,8 +137,8 @@ def test_join(setup_models):
     assert joined_pairs == expected_pairs
 
 
-def test_join_cascade(setup_models):
-    gpt_4o_mini, gpt_4o = setup_models
+def test_join_cascade(setup_gpt_models):
+    gpt_4o_mini, gpt_4o = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)
 
     data1 = {"School": ["UC Berkeley", "Stanford"]}
@@ -163,8 +164,8 @@ def test_join_cascade(setup_models):
     assert stats["filters_resolved_by_helper_model"] == 0, stats
 
 
-def test_map_fewshot(setup_models):
-    gpt_4o_mini, _ = setup_models
+def test_map_fewshot(setup_gpt_models):
+    gpt_4o_mini, _ = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o_mini)
 
     data = {"School": ["UC Berkeley", "Carnegie Mellon"]}
@@ -177,3 +178,29 @@ def test_map_fewshot(setup_models):
     pairs = set(zip(df["School"], df["State"]))
     expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")])
     assert pairs == expected_pairs
+
+def test_agg_then_map(setup_gpt_models):
+    _, gpt_4o = setup_gpt_models
+    lotus.settings.configure(lm=gpt_4o)
+
+    data = {"Text": ["My name is John", "My name is Jane", "My name is John"]}
+    df = pd.DataFrame(data)
+    agg_instruction = "What is the most common name in {Text}?"
+    agg_df = df.sem_agg(agg_instruction, suffix="draft_output")
+    map_instruction = f"{{draft_output}} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be one word"
+    cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output")
+    assert cleaned_df["final_output"].values[0] == "John"
+
+def test_count_tokens(setup_gpt_models):
+    gpt_4o_mini, _ = setup_gpt_models
+    lotus.settings.configure(lm=gpt_4o_mini)
+
+    tokens = gpt_4o_mini.count_tokens("Hello, world!")
+    assert gpt_4o_mini.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
+    assert tokens < 100
+
+    custom_tokenizer = Tokenizer.from_pretrained("gpt2")
+    custom_lm = LM(model="doesn't matter", tokenizer=custom_tokenizer)
+    tokens = custom_lm.count_tokens("Hello, world!")
+    assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
+    assert tokens < 100
diff --git a/lotus/models/lm.py b/lotus/models/lm.py
index e878bb4f..795ebcf6 100644
--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -1,15 +1,17 @@
 import numpy as np
 from litellm import batch_completion, token_counter
 from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse
+from tokenizers import Tokenizer
 
 from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade
 
 
 class LM:
-    def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs):
+    def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.0, max_ctx_len: int = 128000, max_tokens: int = 512, tokenizer: Tokenizer | None = None, **kwargs):
         self.model = model
         self.max_ctx_len = max_ctx_len
         self.max_tokens = max_tokens
+        self.tokenizer = tokenizer
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
         self.history = []
@@ -34,9 +36,9 @@ def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompleti
     def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade:
         all_tokens = []
         all_confidences = []
-        for resp in range(len(logprobs)):
-            tokens = [logprob.token for logprob in logprobs[resp]]
-            confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]]
+        for resp_logprobs in logprobs:
+            tokens = [logprob.token for logprob in resp_logprobs]
+            confidences = [np.exp(logprob.logprob) for logprob in resp_logprobs]
             all_tokens.append(tokens)
             all_confidences.append(confidences)
         return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences)
@@ -44,24 +46,48 @@ def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLog
     def format_logprobs_for_filter_cascade(
         self, logprobs: list[list[ChatCompletionTokenLogprob]]
     ) -> LogprobsForFilterCascade:
-        all_tokens = []
-        all_confidences = []
+        # Get base cascade format first
+        base_cascade = self.format_logprobs_for_cascade(logprobs)
         all_true_probs = []
-        for resp in range(len(logprobs)):
-            all_tokens.append([logprob.token for logprob in logprobs[resp]])
-            all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]])
-            top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]}
-            true_prob, false_prob = 0, 0
-            if top_logprobs and "True" in top_logprobs and "False" in top_logprobs:
-                true_prob = np.exp(top_logprobs["True"])
-                false_prob = np.exp(top_logprobs["False"])
-                all_true_probs.append(true_prob / (true_prob + false_prob))
-            else:
-                all_true_probs.append(1 if "True" in top_logprobs else 0)
-        return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs)
+
+        def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None:
+            if "True" in token_probs and "False" in token_probs:
+                true_prob = token_probs["True"]
+                false_prob = token_probs["False"]
+                return true_prob / (true_prob + false_prob)
+            return None
+
+        # Get true probabilities for filter cascade
+        for resp_idx, response_logprobs in enumerate(logprobs):
+            true_prob = None
+            for logprob in response_logprobs:
+                token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs}
+                true_prob = get_normalized_true_prob(token_probs)
+                if true_prob is not None:
+                    break
+
+            # Default to 1 if "True" in tokens, 0 if not
+            if true_prob is None:
+                true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0
+
+            all_true_probs.append(true_prob)
+
+        return LogprobsForFilterCascade(
+            tokens=base_cascade.tokens,
+            confidences=base_cascade.confidences,
+            true_probs=all_true_probs
+        )
 
     def count_tokens(self, messages: list[dict[str, str]] | str) -> int:
+        """Count tokens in messages using either a custom tokenizer or the model's default tokenizer."""
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
-        return token_counter(model=self.model, messages=messages)
+
+        kwargs = {"model": self.model, "messages": messages}
+        if self.tokenizer:
+            kwargs["custom_tokenizer"] = {
+                "type": "huggingface_tokenizer",
+                "tokenizer": self.tokenizer
+            }
+
+        return token_counter(**kwargs)
diff --git a/lotus/types.py b/lotus/types.py
index 7852754f..6e11f93a 100644
--- a/lotus/types.py
+++ b/lotus/types.py
@@ -8,7 +8,6 @@ class StatsMixin(BaseModel):
     stats: dict[str, Any] | None = None
 
 
-# TODO: Figure out better logprobs type
 class LogprobsMixin(BaseModel):
     # for each response, we have a list of tokens, and for each token, we have a ChatCompletionTokenLogprob
     logprobs: list[list[ChatCompletionTokenLogprob]] | None = None