Add LM tests to CI (#12)
Adds basic LM tests to the CI/CD pipeline. The following operators are currently tested:
1. Filter
2. Filter + Cascade
3. Top-K
4. Join
5. Map + Few-Shot

Overall, the OpenAI cost of running these tests should be under $0.01.
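For reference, the tests can be reproduced locally by mirroring the steps in the new tests.yml workflow below; a minimal sketch, assuming an OpenAI API key is available in your shell:

pip install -r requirements.txt
pip install -e .
pip install pytest
export OPENAI_API_KEY=<your key>   # placeholder; CI reads the key from secrets.OPENAI_API_KEY
pytest .github/tests/lm_tests.py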
sidjha1 authored Oct 1, 2024
1 parent 2140cbc commit cb35763
Showing 4 changed files with 174 additions and 40 deletions.
105 changes: 105 additions & 0 deletions .github/tests/lm_tests.py
@@ -0,0 +1,105 @@
import pandas as pd
import pytest

import lotus
from lotus.models import OpenAIModel

# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")


@pytest.fixture
def setup_models():
    # Setup GPT models
    gpt_4o_mini = OpenAIModel(model="gpt-4o-mini")
    gpt_4o = OpenAIModel(model="gpt-4o")
    return gpt_4o_mini, gpt_4o


def test_filter_operation(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    # Test filter operation on an easy dataframe
    data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
    df = pd.DataFrame(data)
    user_instruction = "{Text} is a positive sentiment"
    filtered_df = df.sem_filter(user_instruction)

    expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]})
    assert filtered_df.equals(expected_df)


def test_filter_cascade(setup_models):
    gpt_4o_mini, gpt_4o = setup_models
    lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)

    data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
    df = pd.DataFrame(data)
    user_instruction = "{Text} is a positive sentiment"

    # With a threshold of 0, all filters should be resolved by the helper model
    filtered_df, stats = df.sem_filter(user_instruction, cascade_threshold=0, return_stats=True)
    assert stats["filters_resolved_by_large_model"] == 0, stats
    assert stats["filters_resolved_by_helper_model"] == 2, stats
    expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]})
    assert filtered_df.equals(expected_df)

    # With a threshold above 1, all filters should be resolved by the large model
    filtered_df, stats = df.sem_filter(user_instruction, cascade_threshold=1.01, return_stats=True)
    assert stats["filters_resolved_by_large_model"] == 2, stats
    assert stats["filters_resolved_by_helper_model"] == 0, stats
    assert filtered_df.equals(expected_df)


def test_top_k(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    data = {
        "Text": [
            "Lionel Messi is a good soccer player",
            "Michael Jordan is a good basketball player",
            "Steph Curry is a good basketball player",
            "Tom Brady is a good football player",
        ]
    }
    df = pd.DataFrame(data)
    user_instruction = "Which {Text} is most related to basketball?"
    sorted_df = df.sem_topk(user_instruction, K=2)

    top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"])
    top_2_actual = set(sorted_df["Text"].values)
    assert top_2_expected == top_2_actual


def test_join(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    data1 = {"School": ["UC Berkeley", "Stanford"]}
    data2 = {"School Type": ["Public School", "Private School"]}

    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)
    join_instruction = "{School} is a {School Type}"
    joined_df = df1.sem_join(df2, join_instruction)
    joined_pairs = set(zip(joined_df["School"], joined_df["School Type"]))
    expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")])
    assert joined_pairs == expected_pairs


def test_map_fewshot(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    data = {"School": ["UC Berkeley", "Carnegie Mellon"]}
    df = pd.DataFrame(data)
    examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]}
    examples_df = pd.DataFrame(examples)
    user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation."
    df = df.sem_map(user_instruction, examples=examples_df, suffix="State")

    pairs = set(zip(df["School"], df["State"]))
    expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")])
    assert pairs == expected_pairs
31 changes: 0 additions & 31 deletions .github/workflows/ruff.yml

This file was deleted.

60 changes: 60 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,60 @@
name: Tests and Linting

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff_lint:
    name: Ruff Lint
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff==0.5.2
      - name: Run ruff
        run: ruff check .

  test:
    name: LM Tests
    runs-on: ubuntu-latest
    timeout-minutes: 5

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -e .
          pip install pytest
      - name: Set OpenAI API Key
        run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV

      - name: Run Python tests
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: pytest .github/tests/lm_tests.py
18 changes: 9 additions & 9 deletions lotus/templates/task_instructions.py
@@ -11,7 +11,7 @@ def filter_formatter_cot(
     cot_reasoning: List[str],
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide a claim and some relevant context.\n"
+        "The user will provide a claim and some relevant context.\n"
         "Your job is to determine whether the claim is true for the given context.\n"
         'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
     )
@@ -45,7 +45,7 @@ def filter_formatter_zs_cot(
     user_instruction: str,
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide a claim and some relevant context.\n"
+        "The user will provide a claim and some relevant context.\n"
         "Your job is to determine whether the claim is true for the given context.\n"
         'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
     )
@@ -71,7 +71,7 @@ def filter_formatter(
         return filter_formatter_zs_cot(df_text, user_instruction)

     sys_instruction = (
-        "The user will povide a claim and some relevant context.\n"
+        "The user will provide a claim and some relevant context.\n"
         "Your job is to determine whether the claim is true for the given context.\n"
         'You must answer with a single word, "True" or "False".'
     )
@@ -103,7 +103,7 @@ def map_formatter_cot(
     cot_reasoning: List[str],
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to answer the user's instruction given the context."
         "You must give your reasoning and then your final answer"
     )
@@ -119,7 +119,7 @@ def map_formatter_cot(
             [
                 {
                     "role": "user",
-                    "content": f"Context:\n{ex_df_txt}\n\Instruction: {user_instruction}",
+                    "content": f"Context:\n{ex_df_txt}\nInstruction: {user_instruction}",
                 },
                 {
                     "role": "assistant",
@@ -142,7 +142,7 @@ def map_formatter_zs_cot(
     user_instruction: str,
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to answer the user's instruction given the context."
         'First give your reasoning. Then you MUST end your output with "Answer: your answer"'
     )
@@ -153,7 +153,7 @@ def map_formatter_zs_cot(
     messages.append(
         {
             "role": "user",
-            "content": f"Context:\n{df_text}\n\Instruction: {user_instruction}",
+            "content": f"Context:\n{df_text}\nInstruction: {user_instruction}",
         }
     )
     return messages
@@ -173,7 +173,7 @@ def map_formatter(
         return map_formatter_zs_cot(df_text, user_instruction)

     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to answer the user's instruction given the context."
     )
     messages = [
@@ -203,7 +203,7 @@ def map_formatter(

 def extract_formatter(df_text: str, user_instruction: str) -> List[str]:
     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to extract the information requested in the instruction.\n"
         "Write the response in JSONL format in a single line with the following fields:\n"
         """{"answer": "your answer", "quotes": "quote from context supporting your answer"}"""
