Add LM tests to CI #12

Merged 7 commits on Oct 1, 2024
105 changes: 105 additions & 0 deletions .github/tests/lm_tests.py
@@ -0,0 +1,105 @@
import pandas as pd
import pytest

import lotus
from lotus.models import OpenAIModel

# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")


@pytest.fixture
def setup_models():
    # Set up GPT models
    gpt_4o_mini = OpenAIModel(model="gpt-4o-mini")
    gpt_4o = OpenAIModel(model="gpt-4o")
    return gpt_4o_mini, gpt_4o


def test_filter_operation(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    # Test filter operation on an easy dataframe
    data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
    df = pd.DataFrame(data)
    user_instruction = "{Text} is a positive sentiment"
    filtered_df = df.sem_filter(user_instruction)

    expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]})
    assert filtered_df.equals(expected_df)


def test_filter_cascade(setup_models):
    gpt_4o_mini, gpt_4o = setup_models
    lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)

    data = {"Text": ["I am really excited to go to class today!", "I am very sad"]}
    df = pd.DataFrame(data)
    user_instruction = "{Text} is a positive sentiment"

    # All filters resolved by the helper model
    filtered_df, stats = df.sem_filter(user_instruction, cascade_threshold=0, return_stats=True)
    assert stats["filters_resolved_by_large_model"] == 0, stats
    assert stats["filters_resolved_by_helper_model"] == 2, stats
    expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]})
    assert filtered_df.equals(expected_df)

    # All filters resolved by the large model
    filtered_df, stats = df.sem_filter(user_instruction, cascade_threshold=1.01, return_stats=True)
    assert stats["filters_resolved_by_large_model"] == 2, stats
    assert stats["filters_resolved_by_helper_model"] == 0, stats
    assert filtered_df.equals(expected_df)


def test_top_k(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    data = {
        "Text": [
            "Lionel Messi is a good soccer player",
            "Michael Jordan is a good basketball player",
            "Steph Curry is a good basketball player",
            "Tom Brady is a good football player",
        ]
    }
    df = pd.DataFrame(data)
    user_instruction = "Which {Text} is most related to basketball?"
    sorted_df = df.sem_topk(user_instruction, K=2)

    top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"])
    top_2_actual = set(sorted_df["Text"].values)
    assert top_2_expected == top_2_actual


def test_join(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    data1 = {"School": ["UC Berkeley", "Stanford"]}
    data2 = {"School Type": ["Public School", "Private School"]}

    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)
    join_instruction = "{School} is a {School Type}"
    joined_df = df1.sem_join(df2, join_instruction)
    joined_pairs = set(zip(joined_df["School"], joined_df["School Type"]))
    expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")])
    assert joined_pairs == expected_pairs


def test_map_fewshot(setup_models):
    gpt_4o_mini, _ = setup_models
    lotus.settings.configure(lm=gpt_4o_mini)

    data = {"School": ["UC Berkeley", "Carnegie Mellon"]}
    df = pd.DataFrame(data)
    examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]}
    examples_df = pd.DataFrame(examples)
    user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation."
    df = df.sem_map(user_instruction, examples=examples_df, suffix="State")

    pairs = set(zip(df["School"], df["State"]))
    expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")])
    assert pairs == expected_pairs
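
For reference, a minimal standalone sketch of the sem_filter flow these tests exercise, using only the model name, settings call, and instruction that appear in the fixture and filter test above; it assumes OPENAI_API_KEY is exported in the environment and is not part of the PR itself.

# Standalone sketch mirroring test_filter_operation above.
# Assumes OPENAI_API_KEY is set in the environment; model name matches the fixture.
import pandas as pd

import lotus
from lotus.models import OpenAIModel

lm = OpenAIModel(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

df = pd.DataFrame({"Text": ["I am really excited to go to class today!", "I am very sad"]})
filtered_df = df.sem_filter("{Text} is a positive sentiment")
print(filtered_df)  # expected to keep only the positive-sentiment row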
31 changes: 0 additions & 31 deletions .github/workflows/ruff.yml

This file was deleted.

60 changes: 60 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,60 @@
name: Tests and Linting

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff_lint:
    name: Ruff Lint
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff==0.5.2

      - name: Run ruff
        run: ruff check .

  test:
    name: LM Tests
    runs-on: ubuntu-latest
    timeout-minutes: 5

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -e .
          pip install pytest

      - name: Set OpenAI API Key
        run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV

      - name: Run Python tests
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: pytest .github/tests/lm_tests.py
18 changes: 9 additions & 9 deletions lotus/templates/task_instructions.py
@@ -11,7 +11,7 @@ def filter_formatter_cot(
     cot_reasoning: List[str],
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide a claim and some relevant context.\n"
+        "The user will provide a claim and some relevant context.\n"
         "Your job is to determine whether the claim is true for the given context.\n"
         'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
     )
@@ -45,7 +45,7 @@ def filter_formatter_zs_cot(
     user_instruction: str,
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide a claim and some relevant context.\n"
+        "The user will provide a claim and some relevant context.\n"
         "Your job is to determine whether the claim is true for the given context.\n"
         'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
     )
@@ -71,7 +71,7 @@ def filter_formatter(
         return filter_formatter_zs_cot(df_text, user_instruction)

     sys_instruction = (
-        "The user will povide a claim and some relevant context.\n"
+        "The user will provide a claim and some relevant context.\n"
         "Your job is to determine whether the claim is true for the given context.\n"
         'You must answer with a single word, "True" or "False".'
     )
@@ -103,7 +103,7 @@ def map_formatter_cot(
     cot_reasoning: List[str],
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to answer the user's instruction given the context."
         "You must give your reasoning and then your final answer"
     )
@@ -119,7 +119,7 @@ def map_formatter_cot(
             [
                 {
                     "role": "user",
-                    "content": f"Context:\n{ex_df_txt}\n\Instruction: {user_instruction}",
+                    "content": f"Context:\n{ex_df_txt}\nInstruction: {user_instruction}",
                 },
                 {
                     "role": "assistant",
@@ -142,7 +142,7 @@ def map_formatter_zs_cot(
     user_instruction: str,
 ) -> List[str]:
     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to answer the user's instruction given the context."
         'First give your reasoning. Then you MUST end your output with "Answer: your answer"'
     )
@@ -153,7 +153,7 @@ def map_formatter_zs_cot(
     messages.append(
         {
             "role": "user",
-            "content": f"Context:\n{df_text}\n\Instruction: {user_instruction}",
+            "content": f"Context:\n{df_text}\nInstruction: {user_instruction}",
         }
     )
     return messages
@@ -173,7 +173,7 @@ def map_formatter(
         return map_formatter_zs_cot(df_text, user_instruction)

     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to answer the user's instruction given the context."
     )
     messages = [
@@ -203,7 +203,7 @@ def map_formatter(

 def extract_formatter(df_text: str, user_instruction: str) -> List[str]:
     sys_instruction = (
-        "The user will povide an instruction and some relevant context.\n"
+        "The user will provide an instruction and some relevant context.\n"
         "Your job is to extract the information requested in the instruction.\n"
         "Write the response in JSONL format in a single line with the following fields:\n"
         """{"answer": "your answer", "quotes": "quote from context supporting your answer"}"""
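
For context, a sketch of the chat messages the corrected zero-shot CoT map prompt produces. Only the strings visible in the hunks above are used; the function name, the system-role wrapper, and the return shape here are illustrative assumptions, not the library's exact internals.

# Illustrative sketch of the message structure implied by the map_formatter_zs_cot
# hunks above; the wrapper around the visible strings is an assumption.
from typing import Dict, List


def map_zs_cot_messages_sketch(df_text: str, user_instruction: str) -> List[Dict[str, str]]:
    sys_instruction = (
        "The user will provide an instruction and some relevant context.\n"
        "Your job is to answer the user's instruction given the context."
        'First give your reasoning. Then you MUST end your output with "Answer: your answer"'
    )
    return [
        {"role": "system", "content": sys_instruction},
        {"role": "user", "content": f"Context:\n{df_text}\nInstruction: {user_instruction}"},
    ]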