
Commit

token counting + tests
sidjha1 committed Nov 2, 2024
1 parent 761df36 commit bf8c73a
Showing 3 changed files with 85 additions and 33 deletions.
53 changes: 40 additions & 13 deletions .github/tests/lm_tests.py
@@ -3,21 +3,22 @@

import lotus
from lotus.models import LM
from tokenizers import Tokenizer

# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")


@pytest.fixture
def setup_models():
def setup_gpt_models():
# Setup GPT models
gpt_4o_mini = LM(model="gpt-4o-mini")
gpt_4o = LM(model="gpt-4o")
return gpt_4o_mini, gpt_4o


def test_filter_operation(setup_models):
gpt_4o_mini, _ = setup_models
def test_filter_operation(setup_gpt_models):
gpt_4o_mini, _ = setup_gpt_models
lotus.settings.configure(lm=gpt_4o_mini)

# Test filter operation on an easy dataframe
@@ -30,8 +31,8 @@ def test_filter_operation(setup_models):
assert filtered_df.equals(expected_df)


def test_filter_cascade(setup_models):
gpt_4o_mini, gpt_4o = setup_models
def test_filter_cascade(setup_gpt_models):
gpt_4o_mini, gpt_4o = setup_gpt_models
lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)

data = {
@@ -99,8 +100,8 @@ def test_filter_cascade(setup_models):
assert stats["filters_resolved_by_helper_model"] > 0, stats


def test_top_k(setup_models):
gpt_4o_mini, _ = setup_models
def test_top_k(setup_gpt_models):
gpt_4o_mini, _ = setup_gpt_models
lotus.settings.configure(lm=gpt_4o_mini)

data = {
@@ -120,8 +121,8 @@ def test_top_k(setup_models):
assert top_2_expected == top_2_actual


def test_join(setup_models):
gpt_4o_mini, _ = setup_models
def test_join(setup_gpt_models):
gpt_4o_mini, _ = setup_gpt_models
lotus.settings.configure(lm=gpt_4o_mini)

data1 = {"School": ["UC Berkeley", "Stanford"]}
@@ -136,8 +137,8 @@ def test_join(setup_models):
assert joined_pairs == expected_pairs


def test_join_cascade(setup_models):
gpt_4o_mini, gpt_4o = setup_models
def test_join_cascade(setup_gpt_models):
gpt_4o_mini, gpt_4o = setup_gpt_models
lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)

data1 = {"School": ["UC Berkeley", "Stanford"]}
@@ -163,8 +164,8 @@ def test_join_cascade(setup_models):
assert stats["filters_resolved_by_helper_model"] == 0, stats


def test_map_fewshot(setup_models):
gpt_4o_mini, _ = setup_models
def test_map_fewshot(setup_gpt_models):
gpt_4o_mini, _ = setup_gpt_models
lotus.settings.configure(lm=gpt_4o_mini)

data = {"School": ["UC Berkeley", "Carnegie Mellon"]}
@@ -177,3 +178,29 @@ def test_map_fewshot(setup_models):
pairs = set(zip(df["School"], df["State"]))
expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")])
assert pairs == expected_pairs

def test_agg_then_map(setup_gpt_models):
_, gpt_4o = setup_gpt_models
lotus.settings.configure(lm=gpt_4o)

data = {"Text": ["My name is John", "My name is Jane", "My name is John"]}
df = pd.DataFrame(data)
agg_instruction = "What is the most common name in {Text}?"
agg_df = df.sem_agg(agg_instruction, suffix="draft_output")
map_instruction = f"{{draft_output}} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be one word"
cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output")
assert cleaned_df["final_output"].values[0] == "John"

def test_count_tokens(setup_gpt_models):
gpt_4o_mini, _ = setup_gpt_models
lotus.settings.configure(lm=gpt_4o_mini)

tokens = gpt_4o_mini.count_tokens("Hello, world!")
assert gpt_4o_mini.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
assert tokens < 100

custom_tokenizer = Tokenizer.from_pretrained("gpt2")
custom_lm = LM(model="doesn't matter", tokenizer=custom_tokenizer)
tokens = custom_lm.count_tokens("Hello, world!")
assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
assert tokens < 100
64 changes: 45 additions & 19 deletions lotus/models/lm.py
@@ -1,15 +1,17 @@
import numpy as np
from litellm import batch_completion, token_counter
from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse
from tokenizers import Tokenizer

from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade


class LM:
def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs):
def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.0, max_ctx_len: int = 128000, max_tokens: int = 512, tokenizer: Tokenizer | None = None, **kwargs):
self.model = model
self.max_ctx_len = max_ctx_len
self.max_tokens = max_tokens
self.tokenizer = tokenizer
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
self.history = []

@@ -34,34 +36,58 @@ def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompleti
def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade:
all_tokens = []
all_confidences = []
for resp in range(len(logprobs)):
tokens = [logprob.token for logprob in logprobs[resp]]
confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]]
for resp_logprobs in logprobs:
tokens = [logprob.token for logprob in resp_logprobs]
confidences = [np.exp(logprob.logprob) for logprob in resp_logprobs]
all_tokens.append(tokens)
all_confidences.append(confidences)
return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences)

def format_logprobs_for_filter_cascade(
self, logprobs: list[list[ChatCompletionTokenLogprob]]
) -> LogprobsForFilterCascade:
all_tokens = []
all_confidences = []
# Get base cascade format first
base_cascade = self.format_logprobs_for_cascade(logprobs)
all_true_probs = []

for resp in range(len(logprobs)):
all_tokens.append([logprob.token for logprob in logprobs[resp]])
all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]])
top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]}
true_prob, false_prob = 0, 0
if top_logprobs and "True" in top_logprobs and "False" in top_logprobs:
true_prob = np.exp(top_logprobs["True"])
false_prob = np.exp(top_logprobs["False"])
all_true_probs.append(true_prob / (true_prob + false_prob))
else:
all_true_probs.append(1 if "True" in top_logprobs else 0)
return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs)
def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None:
if "True" in token_probs and "False" in token_probs:
true_prob = token_probs["True"]
false_prob = token_probs["False"]
return true_prob / (true_prob + false_prob)
return None

# Get true probabilities for filter cascade
for resp_idx, response_logprobs in enumerate(logprobs):
true_prob = None
for logprob in response_logprobs:
token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs}
true_prob = get_normalized_true_prob(token_probs)
if true_prob is not None:
break

# Default to 1 if "True" in tokens, 0 if not
if true_prob is None:
true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0

all_true_probs.append(true_prob)

return LogprobsForFilterCascade(
tokens=base_cascade.tokens,
confidences=base_cascade.confidences,
true_probs=all_true_probs
)
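
For reference, the normalization introduced above maps a token's top logprobs to a single probability that the filter answered "True": P(True) / (P(True) + P(False)), falling back to 1 or 0 when only one of the two labels appears. A minimal sketch of that arithmetic with made-up logprob values (illustrative only, not part of the commit):

import numpy as np

# Made-up top logprobs for one generated token (values are illustrative).
top_logprobs = {"True": -0.105, "False": -2.303}

# Convert logprobs to probabilities, then normalize over the two labels.
token_probs = {tok: np.exp(lp) for tok, lp in top_logprobs.items()}
true_prob = token_probs["True"] / (token_probs["True"] + token_probs["False"])
print(round(true_prob, 2))  # ~0.9, the value appended to all_true_probs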

def count_tokens(self, messages: list[dict[str, str]] | str) -> int:
"""Count tokens in messages using either custom tokenizer or model's default tokenizer"""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
return token_counter(model=self.model, messages=messages)

kwargs = {"model": self.model, "messages": messages}
if self.tokenizer:
kwargs["custom_tokenizer"] = {
"type": "huggingface_tokenizer",
"tokenizer": self.tokenizer
}

return token_counter(**kwargs)
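
Taken together with the tests above, the new count_tokens path can be exercised roughly as in the sketch below. The model string "placeholder" and the gpt2 tokenizer are illustrative choices, and per test_count_tokens the model name appears to be irrelevant once a custom tokenizer is supplied.

from tokenizers import Tokenizer
from lotus.models import LM

# Default path: litellm's token_counter resolves a tokenizer from the model name.
lm = LM(model="gpt-4o-mini")
n_str = lm.count_tokens("Hello, world!")
n_msg = lm.count_tokens([{"role": "user", "content": "Hello, world!"}])
assert n_str == n_msg  # a plain string is wrapped as a single user message

# Custom path: a HuggingFace tokenizer is forwarded to token_counter as
# {"type": "huggingface_tokenizer", "tokenizer": ...}.
custom_lm = LM(model="placeholder", tokenizer=Tokenizer.from_pretrained("gpt2"))
print(custom_lm.count_tokens("Hello, world!"))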
1 change: 0 additions & 1 deletion lotus/types.py
@@ -8,7 +8,6 @@ class StatsMixin(BaseModel):
stats: dict[str, Any] | None = None


# TODO: Figure out better logprobs type
class LogprobsMixin(BaseModel):
# for each response, we have a list of tokens, and for each token, we have a ChatCompletionTokenLogprob
logprobs: list[list[ChatCompletionTokenLogprob]] | None = None
