Skip to content

Commit

Permalink
Add unit test for vec_db
Browse files Browse the repository at this point in the history
  • Loading branch information
StdioA committed Aug 25, 2024
1 parent 224ec7b commit af01366
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 5 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-full.txt
pip install "sqlite-vec==0.1.1"
- name: Test with pytest
run: |
pip install pytest pytest-cov coverage
coverage run -m pytest
coverage report
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ __pycache__
tx_db.*
config.yaml
**/*.po~
.coverage
htmlcov
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ $(foreach lang,$(LANGUAGES),$(eval $(call po_rule,$(lang))))

lint:
@ruff check

test:
coverage run -m pytest
coverage report
@coverage html --include="**/*.py"
5 changes: 5 additions & 0 deletions conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
def load_config(config_path):
    """Build a ``Config`` from *config_path* and install it as the module-level ``config``.

    Args:
        config_path: Location of the configuration to load; passed
            straight through to ``Config``.
    """
    global config
    loaded = Config(config_path)
    config = loaded


def _load_config_from_dict(config_dict):
    """Build a ``Config`` from an in-memory dict and install it as the module-level ``config``.

    Args:
        config_dict: Configuration data handed to ``Config.from_dict``.
    """
    global config
    loaded = Config.from_dict(config_dict)
    config = loaded
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ exclude = [
"site-packages",
"venv",
"test.py",
"*_test.py",
]

# Same as Black.
Expand Down
65 changes: 65 additions & 0 deletions vec_db/json_vec_db_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest
from typing import List
from conf import _load_config_from_dict, config
import tempfile
from vec_db import json_vec_db


@pytest.fixture
def mock_config():
    """Install a minimal embedding configuration and return the live config.

    The configuration enables embeddings and points ``db_store_folder`` at
    the system temporary directory.

    Returns:
        The configuration object currently bound to ``conf.config``.
    """
    import conf

    conf_data = {
        "embedding": {
            "enable": True,
            "db_store_folder": tempfile.gettempdir(),
        }
    }
    _load_config_from_dict(conf_data)
    # `_load_config_from_dict` rebinds the `conf.config` global. A name pulled
    # in with `from conf import config` keeps pointing at whatever was bound
    # at import time, so read the attribute from the module instead.
    return conf.config


def easy_embedding(content: str) -> List[float]:
    """Turn *content* into a deterministic toy embedding vector.

    Each byte of the encoded string becomes one float component, and the
    vector is right-padded with zeros up to 64 components. Input that
    encodes to more than 64 bytes is returned unpadded and untruncated.
    """
    target_width = 64
    vector = [float(byte) for byte in content.encode()]
    # A non-positive repeat count yields an empty list, so long inputs are
    # left exactly as they are.
    vector.extend([0.0] * (target_width - len(vector)))
    return vector


def test_json_db(mock_config):
    """Build the JSON vector DB from three records and query the nearest two.

    The query uses the embedding of ``"content-1"``, so the exact match
    (``hash-1``) must rank first and the near-identical ``"content-2"``
    (``hash-2``) second.
    """
    # Build DB
    txs = [
        {
            "hash": "hash-1",
            "occurance": 1,
            "sentence": "sentence-1",
            "content": "content-1",
            "embedding": easy_embedding("content-1"),
        },
        {
            "hash": "hash-2",
            "occurance": 1,
            "sentence": "sentence-2",
            "content": "content-2",
            "embedding": easy_embedding("content-2"),
        },
        {
            "hash": "hash-3",
            "occurance": 1,
            "sentence": "sentence-3",
            "content": "another-3",
            "embedding": easy_embedding("another-3"),
        },
    ]
    json_vec_db.build_db(txs)
    db_path = json_vec_db._get_db_name()
    try:
        assert db_path.exists()
        # Query DB
        candidates = json_vec_db.query_by_embedding(
            easy_embedding("content-1"), "sentence-1", 2,
        )
        assert len(candidates) == 2
        assert candidates[0]["hash"] == "hash-1"
        assert candidates[1]["hash"] == "hash-2"
    finally:
        # Cleanup runs even when an assertion fails, so a broken run does not
        # leave a stale DB file behind in the shared temp directory.
        if db_path.exists():
            db_path.unlink()
12 changes: 7 additions & 5 deletions vec_db/sqlite_vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def build_db(txs):
if txs:
embedding_dimention = len(txs[0]["embedding"])

# Drop table
db.execute("DROP TABLE vec_items")
db.execute("DROP TABLE transactions")
# Drop table if exists
db.execute("DROP TABLE IF EXISTS vec_items")
db.execute("DROP TABLE IF EXISTS transactions")
db.commit()
db.execute("VACUUM")
db.commit()
Expand All @@ -71,13 +71,15 @@ def query_by_embedding(embedding, sentence, candidate_amount):
db = get_db()

try:
# sqlite-vec's `vec_distance_cosine` returns the cosine *distance* (1 - cosine similarity),
# so subtract it from 1 to map the [0, 2] distance range back to a similarity in [1, -1].
rows = db.execute(
f"""
SELECT
rowid,
vec_distance_cosine(embedding, ?) AS distance
1-vec_distance_cosine(embedding, ?) AS distance
FROM vec_items
ORDER BY distance LIMIT {candidate_amount}
ORDER BY distance DESC LIMIT {candidate_amount}
""",
(serialize_f32(embedding),)).fetchall()
except sqlite3.OperationalError as e:
Expand Down
48 changes: 48 additions & 0 deletions vec_db/sqlite_vec_db_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
from vec_db.json_vec_db_test import easy_embedding, mock_config

try:
import sqlite_vec as _
except ImportError:
pytest.skip("skipping module tests due to sqlite_vec not installed", allow_module_level=True)
else:
from vec_db import sqlite_vec_db


def test_sqlite_db(mock_config):
    """Build the sqlite-vec DB from three records and query the nearest two.

    The query uses the embedding of ``"content-1"``, so the exact match
    must rank first and the near-identical ``"content-2"`` second.
    """
    # Build DB
    txs = [
        {
            "hash": "hash-1",
            "occurance": 1,
            "sentence": "sentence-1",
            "content": "content-1",
            "embedding": easy_embedding("content-1"),
        },
        {
            "hash": "hash-2",
            "occurance": 1,
            "sentence": "sentence-2",
            "content": "content-2",
            "embedding": easy_embedding("content-2"),
        },
        {
            "hash": "hash-3",
            "occurance": 1,
            "sentence": "sentence-3",
            "content": "another-3",
            "embedding": easy_embedding("another-3"),
        },
    ]
    sqlite_vec_db.build_db(txs)
    db_path = sqlite_vec_db._get_db_name()
    try:
        assert db_path.exists()
        # Query DB
        candidates = sqlite_vec_db.query_by_embedding(
            easy_embedding("content-1"), "sentence-1", 2,
        )
        assert len(candidates) == 2
        assert candidates[0]["content"] == "content-1"
        assert candidates[1]["content"] == "content-2"
    finally:
        # Cleanup runs even when an assertion fails, so a broken run does not
        # leave a stale DB file behind in the shared temp directory.
        if db_path.exists():
            db_path.unlink()

0 comments on commit af01366

Please sign in to comment.