Skip to content

Commit

Permalink
Add unit test for vec_query
Browse files Browse the repository at this point in the history
  • Loading branch information
StdioA committed Aug 26, 2024
1 parent 236937c commit 46f76f4
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install "pip<24.1"
pip install -r requirements-full.txt
pip install -r requirements/full.txt
pip install "sqlite-vec==0.1.1"
pip install pytest pytest-cov coverage codecov
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion bean_utils/bean.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from beancount.core.data import Open, Transaction
from beancount.core.number import MISSING
from typing import List
from bean_utils.txs_query import query_txs
from bean_utils.vec_query import query_txs
from bean_utils.rag import complete_rag
import conf

Expand Down
9 changes: 4 additions & 5 deletions bean_utils/bean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from conf import _load_config_from_dict
from conf.config_data import Config
from beancount.parser import parser
from bean_utils import bean, txs_query
from bean_utils import bean, vec_query


today = str(datetime.now().astimezone().date())
Expand Down Expand Up @@ -159,7 +159,6 @@ def test_generate_trx_with_vector_db(mock_config, monkeypatch):
"candidates": 3,
"output_amount": 2,
}))
# monkeypatch.setattr(txs_query, "embedding", mock_embedding)
def _mock_embedding_post(*args, json={}, **kwargs):
result, tokens = mock_embedding(json["input"])
return MockResponse({
Expand All @@ -169,7 +168,7 @@ def _mock_embedding_post(*args, json={}, **kwargs):
monkeypatch.setattr(requests, "post", _mock_embedding_post)

manager = bean.BeanManager(mock_config.beancount.filename)
txs_query.build_tx_db(manager.entries)
vec_query.build_tx_db(manager.entries)
trx = manager.generate_trx('10.00 "Kin Soy", "Eating"')
# The match effect is not guaranteed in this test due to incorrect embedding implementation
assert len(trx) == 2
Expand Down Expand Up @@ -202,13 +201,13 @@ def test_generate_trx_with_rag(mock_config, monkeypatch):
monkeypatch.setattr(mock_config, "rag", Config.from_dict({
"enable": True,
}))
monkeypatch.setattr(txs_query, "embedding", mock_embedding)
monkeypatch.setattr(vec_query, "embedding", mock_embedding)
monkeypatch.setattr(requests, "post", mock_post({"message": {"content": exp_trx}}))

# Test RAG fallback
manager = bean.BeanManager(mock_config.beancount.filename)
trx = manager.generate_trx('10.00 "Kin Soy", "Eating"')
txs_query.build_tx_db(manager.entries)
vec_query.build_tx_db(manager.entries)
# The match effect is not guaranteed in this test due to incorrect embedding implementation
assert len(trx) == 1
assert_txs_equal(trx[0], exp_trx)
Expand Down
4 changes: 2 additions & 2 deletions bean_utils/rag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests
from vec_db import query_by_embedding
from bean_utils import txs_query
from bean_utils import vec_query
import conf


Expand Down Expand Up @@ -38,7 +38,7 @@ def complete_rag(args, date, accounts):
candidates = conf.config.embedding.candidates or 3
rag_config = conf.config.rag

match = query_by_embedding(txs_query.embedding([stripped_input])[0][0]["embedding"], stripped_input, candidates)
match = query_by_embedding(vec_query.embedding([stripped_input])[0][0]["embedding"], stripped_input, candidates)
reference_records = "\n------\n".join([x["content"] for x in match])
prompt = _PROMPT_TEMPLATE.format(date=date, reference_records=reference_records, accounts=accounts)
payload = {
Expand Down
58 changes: 25 additions & 33 deletions bean_utils/txs_query.py → bean_utils/vec_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def convert_account(account):
def escape_quotes(s):
if not s:
return s
return s.replace('"', '\\"').replace("'", "\\'")
return s.replace('"', '\\"')


def convert_to_natural_language(transaction) -> str:
Expand All @@ -57,45 +57,37 @@ def convert_to_natural_language(transaction) -> str:
return sentence


content_cache = {}


def read_lines(fname, start, end):
global content_cache
if fname not in content_cache:
with open(fname) as f:
lines = f.readlines()
content_cache[fname] = lines
return content_cache[fname][start-1:end]


def build_tx_db(transactions):
_content_cache = {}
def _read_lines(fname, start, end):
if fname not in _content_cache:
with open(fname) as f:
_content_cache[fname] = f.readlines()
return _content_cache[fname][start-1:end]

unique_txs = {}
amount = conf.config.embedding.transaction_amount
# Build latest transactions
for entry in reversed(transactions):
if not isinstance(entry, Transaction):
continue
try:
sentence = convert_to_natural_language(entry)
if sentence is None:
continue
if sentence in unique_txs:
unique_txs[sentence]["occurance"] += 1
continue
fname = entry.meta['filename']
start_lineno = entry.meta['lineno']
end_lineno = max(p.meta['lineno'] for p in entry.postings)
unique_txs[sentence] = {
"sentence": sentence,
"hash": hash_entry(entry),
"occurance": 1,
"content": "".join(read_lines(fname, start_lineno, end_lineno)),
}
if len(unique_txs) >= amount:
break
except Exception:
raise
sentence = convert_to_natural_language(entry)
if sentence is None:
continue
if sentence in unique_txs:
unique_txs[sentence]["occurance"] += 1
continue
fname = entry.meta['filename']
start_lineno = entry.meta['lineno']
end_lineno = max(p.meta['lineno'] for p in entry.postings)
unique_txs[sentence] = {
"sentence": sentence,
"hash": hash_entry(entry),
"occurance": 1,
"content": "".join(_read_lines(fname, start_lineno, end_lineno)),
}
if len(unique_txs) >= amount:
break
# Build embedding by group
total_usage = 0
unique_txs_list = list(unique_txs.values())
Expand Down
50 changes: 50 additions & 0 deletions bean_utils/vec_query_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
from bean_utils import vec_query
from conf import _load_config_from_dict
from conf.config_data import Config
from beancount.parser import parser
from bean_utils.bean import parse_args


@pytest.fixture
def mock_config():
    """Load a minimal configuration and return the resulting config object.

    Only ``beancount.account_distinguation_range`` is supplied; every other
    option is left to whatever ``_load_config_from_dict`` does with missing keys.
    """
    beancount_section = {"account_distinguation_range": [1, 2]}
    return _load_config_from_dict({"beancount": beancount_section})


@pytest.mark.parametrize(
    "account, arg, exp",
    [
        ("Assets:BoFA:Checking", 1, "BoFA"),
        ("Assets:BoFA:Checking", [1, 1], "BoFA"),
        ("Assets:BoFA:Checking", [1, 2], "BoFA:Checking"),
        ("Assets:BoFA:Checking", [1, 5], "BoFA:Checking"),
    ],
)
def test_convert_account(account, arg, exp, mock_config, monkeypatch):
    """Check ``convert_account`` for several ``account_distinguation_range`` values.

    ``arg`` is installed as the range (either a single int or a two-item list)
    and ``exp`` is the account string expected back for ``account``.
    """
    # NOTE(review): mock_config's value is unused here; it is presumably
    # requested so that the configuration is loaded before patching - confirm.
    beancount_section = Config.from_dict({"account_distinguation_range": arg})
    monkeypatch.setattr(vec_query.conf.config, "beancount", beancount_section)

    converted = vec_query.convert_account(account)
    assert converted == exp


def test_convert_to_natual_language(mock_config, monkeypatch):
    """Check the sentence built from a transaction and that it round-trips.

    The transaction's payee and narration contain double and single quotes;
    the generated sentence must match the expected text exactly and must be
    split back into the same four fields by ``parse_args``.
    """
    # NOTE(review): mock_config is requested (as in test_convert_account) so
    # this test does not rely on another test having loaded the configuration
    # first - presumably _load_config_from_dict populates conf.config; confirm.
    #
    # Postings are indented under the transaction header, as beancount syntax
    # requires; the header itself starts at column 0.
    trx_str = """
2022-01-01 * "Discount 'abc'" "Discount"
    Assets:US:BofA:Checking 4264.93 USD
    Equity:Opening-Balances -4264.93 USD
"""
    monkeypatch.setattr(vec_query.conf.config, "beancount", Config.from_dict({
        "account_distinguation_range": [1, 2],
    }))
    trx, errors, _ = parser.parse_string(trx_str)
    # Fail early with a clear signal if the fixture text itself did not parse.
    assert not errors

    result = vec_query.convert_to_natural_language(trx[0])
    assert result == '"Discount \'abc\'" "Discount" US:BofA Opening-Balances'
    args = parse_args(result)
    assert args == ["Discount 'abc'", "Discount", "US:BofA", "Opening-Balances"]
4 changes: 2 additions & 2 deletions bots/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List, Union, Any
from beancount.core.inventory import Inventory
import requests
from bean_utils import txs_query
from bean_utils import vec_query
from bean_utils.bean import bean_manager, NoTransactionError
import conf

Expand All @@ -31,7 +31,7 @@ def build_db() -> BaseMessage:
if not conf.config.embedding.get("enable", True):
return BaseMessage(content=_("Embedding is not enabled."))
entries = bean_manager.entries
tokens = txs_query.build_tx_db(entries)
tokens = vec_query.build_tx_db(entries)
return BaseMessage(content=_("Token usage: {tokens}").format(tokens=tokens))


Expand Down
4 changes: 2 additions & 2 deletions bots/controller_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from datetime import datetime, date
from bean_utils import txs_query
from bean_utils import vec_query
from conf import _load_config_from_dict
from bean_utils.bean import init_bean_manager
from bots import controller
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_build_db(monkeypatch, mock_env):
"candidates": 3,
"output_amount": 2,
}))
monkeypatch.setattr(txs_query, "embedding", mock_embedding)
monkeypatch.setattr(vec_query, "embedding", mock_embedding)
response = controller.build_db()
assert isinstance(response, controller.BaseMessage)
assert response.content == f"Token usage: {mock_env.embedding.transaction_amount}"
26 changes: 15 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
import conf
import logging
from bean_utils.bean import init_bean_manager
from bean_utils import bean


if __name__ == "__main__":
def main():
# Init logging
logging.basicConfig(level=logging.INFO)
# logging.getLogger().addHandler(logging.StreamHandler())
Expand All @@ -19,17 +19,21 @@
mattermost_parser.add_argument('-c', nargs="?", type=str, default="config.yaml", help="config file path")

args = parser.parse_args()
if args.command is not None:
conf.load_config(args.c)
# Init i18n
conf.init_locale()
init_bean_manager()
if args.command is None:
parser.print_help()
return

conf.load_config(args.c)
# Init i18n
conf.init_locale()
bean.init_bean_manager()

if args.command == "telegram":
from bots.telegram_bot import run_bot
run_bot()
elif args.command == "mattermost":
from bots.mmbot import run_bot
run_bot()
else:
parser.print_help()
run_bot()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion vec_db/json_vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def query_by_embedding(embedding, sentence, candidate_amount):
try:
with open(_get_db_name()) as f:
transactions = json.load(f)
except FileExistsError:
except FileNotFoundError:
logging.warning("JSON vector database is not built")
return None
embed_query = np.array(embedding)
Expand Down
6 changes: 6 additions & 0 deletions vec_db/json_vec_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def test_json_db(mock_config):
"embedding": easy_embedding("another-3"),
},
]
# Query when the DB does not exist yet
candidates = json_vec_db.query_by_embedding(
easy_embedding("content-1"), "sentence-1", 2,
)
assert candidates is None
# Build DB
json_vec_db.build_db(txs)
db_path = json_vec_db._get_db_name()
assert db_path.exists()
Expand Down

0 comments on commit 46f76f4

Please sign in to comment.