diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index d82cf611..3940944a 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -79,6 +79,24 @@ def test_search_reranker_only(setup_models): assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] +def test_search(setup_models): + rm, reranker = setup_models + lotus.settings.configure(rm=rm, reranker=reranker) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "index_dir") + df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] + + def test_dedup(setup_models): rm, _ = setup_models lotus.settings.configure(rm=rm) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 31d1372f..e5f7fc46 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,7 +31,7 @@ jobs: run: ruff check . test: - name: LM and RM Tests + name: Python Tests runs-on: ubuntu-latest timeout-minutes: 5 @@ -57,6 +57,6 @@ jobs: - name: Run Python tests env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: + run: | pytest .github/tests/lm_tests.py pytest .github/tests/rm_tests.py \ No newline at end of file