diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml new file mode 100644 index 0000000..92e4b64 --- /dev/null +++ b/.github/workflows/unit_test.yml @@ -0,0 +1,34 @@ +name: Run Unit Test + +on: + push: + branches: + - master + - feature/unit_test + paths-ignore: + - 'README.md' + - 'README_zh.md' + - 'Makefile' + - 'ruff.toml' + workflow_dispatch: + +jobs: + main: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-full.txt + pip install "sqlite-vec==0.1.1" + - name: Test with pytest + run: | + pip install pytest pytest-cov coverage + coverage run -m pytest + coverage report diff --git a/.gitignore b/.gitignore index 9095d57..1e4f279 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ __pycache__ tx_db.* config.yaml **/*.po~ +.coverage +htmlcov diff --git a/Makefile b/Makefile index 6277e05..b63fa2f 100644 --- a/Makefile +++ b/Makefile @@ -36,3 +36,8 @@ $(foreach lang,$(LANGUAGES),$(eval $(call po_rule,$(lang)))) lint: @ruff check + +test: + coverage run -m pytest + coverage report + @coverage html --include="**/*.py" diff --git a/conf/__init__.py b/conf/__init__.py index e572d0a..fcb592f 100644 --- a/conf/__init__.py +++ b/conf/__init__.py @@ -11,3 +11,8 @@ def load_config(config_path): global config config = Config(config_path) + + +def _load_config_from_dict(config_dict): + global config + config = Config.from_dict(config_dict) diff --git a/ruff.toml b/ruff.toml index f1705c2..5faddbd 100644 --- a/ruff.toml +++ b/ruff.toml @@ -27,6 +27,7 @@ exclude = [ "site-packages", "venv", "test.py", + "*_test.py", ] # Same as Black. 
diff --git a/vec_db/json_vec_db_test.py b/vec_db/json_vec_db_test.py new file mode 100644 index 0000000..a8abcf4 --- /dev/null +++ b/vec_db/json_vec_db_test.py @@ -0,0 +1,65 @@ +import pytest +from typing import List +from conf import _load_config_from_dict, config +import tempfile +from vec_db import json_vec_db + + +@pytest.fixture +def mock_config(): + conf_data = { + "embedding": { + "enable": True, + "db_store_folder": tempfile.gettempdir(), + } + } + _load_config_from_dict(conf_data) + return config + + +def easy_embedding(content: str) -> List[float]: + embed = [float(x) for x in content.encode()] + # right pad with zeros + _width = 64 + for _ in range(_width - len(embed)): + embed.append(0.0) + return embed + + +def test_json_db(mock_config): + # Build DB + txs = [ + { + "hash": "hash-1", + "occurance": 1, + "sentence": "sentence-1", + "content": "content-1", + "embedding": easy_embedding("content-1"), + }, + { + "hash": "hash-2", + "occurance": 1, + "sentence": "sentence-2", + "content": "content-2", + "embedding": easy_embedding("content-2"), + }, + { + "hash": "hash-3", + "occurance": 1, + "sentence": "sentence-3", + "content": "another-3", + "embedding": easy_embedding("another-3"), + }, + ] + json_vec_db.build_db(txs) + db_path = json_vec_db._get_db_name() + assert db_path.exists() + # Query DB + candidates = json_vec_db.query_by_embedding( + easy_embedding("content-1"), "sentence-1", 2, + ) + assert len(candidates) == 2 + assert candidates[0]["hash"] == "hash-1" + assert candidates[1]["hash"] == "hash-2" + # Cleanup + db_path.unlink() diff --git a/vec_db/sqlite_vec_db.py b/vec_db/sqlite_vec_db.py index 62b1bd7..0bc1fc2 100644 --- a/vec_db/sqlite_vec_db.py +++ b/vec_db/sqlite_vec_db.py @@ -42,9 +42,9 @@ def build_db(txs): if txs: embedding_dimention = len(txs[0]["embedding"]) - # Drop table - db.execute("DROP TABLE vec_items") - db.execute("DROP TABLE transactions") + # Drop table if exists + db.execute("DROP TABLE IF EXISTS vec_items") + 
db.execute("DROP TABLE IF EXISTS transactions") db.commit() db.execute("VACUUM") db.commit() @@ -71,13 +71,15 @@ def query_by_embedding(embedding, sentence, candidate_amount): db = get_db() try: + # sqlite-vec's `vec_distance_cosine` returns the cosine distance, i.e. `1 - cosine similarity`, in the range [0, 2]. + # Subtracting it from 1 recovers the similarity in [-1, 1] (still aliased as `distance`), so sort DESC to rank the most similar rows first. rows = db.execute( f""" SELECT rowid, - vec_distance_cosine(embedding, ?) AS distance + 1-vec_distance_cosine(embedding, ?) AS distance FROM vec_items - ORDER BY distance LIMIT {candidate_amount} + ORDER BY distance DESC LIMIT {candidate_amount} """, (serialize_f32(embedding),)).fetchall() except sqlite3.OperationalError as e: diff --git a/vec_db/sqlite_vec_db_test.py b/vec_db/sqlite_vec_db_test.py new file mode 100644 index 0000000..4872131 --- /dev/null +++ b/vec_db/sqlite_vec_db_test.py @@ -0,0 +1,48 @@ +import pytest +from vec_db.json_vec_db_test import easy_embedding, mock_config + +try: + import sqlite_vec as _ +except ImportError: + pytest.skip("skipping module tests due to sqlite_vec not installed", allow_module_level=True) +else: + from vec_db import sqlite_vec_db + + +def test_sqlite_db(mock_config): + # Build DB + txs = [ + { + "hash": "hash-1", + "occurance": 1, + "sentence": "sentence-1", + "content": "content-1", + "embedding": easy_embedding("content-1"), + }, + { + "hash": "hash-2", + "occurance": 1, + "sentence": "sentence-2", + "content": "content-2", + "embedding": easy_embedding("content-2"), + }, + { + "hash": "hash-3", + "occurance": 1, + "sentence": "sentence-3", + "content": "another-3", + "embedding": easy_embedding("another-3"), + }, + ] + sqlite_vec_db.build_db(txs) + db_path = sqlite_vec_db._get_db_name() + assert db_path.exists() + # Query DB + candidates = sqlite_vec_db.query_by_embedding( + easy_embedding("content-1"), "sentence-1", 2, + ) + assert len(candidates) == 2 + assert candidates[0]["content"] == "content-1" + assert candidates[1]["content"] == "content-2" + # Cleanup +
db_path.unlink()