From 8370981118ba5ad032d96e5b5daf9cd44b8aabfc Mon Sep 17 00:00:00 2001 From: David Dai Date: Tue, 27 Aug 2024 17:23:52 +0800 Subject: [PATCH] Add docstrings and simplify the structure of main.py --- Makefile | 2 +- bean_utils/bean.py | 129 +++++++++++++++++++++++++-- bean_utils/vec_query.py | 53 ++++++++++- bots/{mmbot.py => mattermost_bot.py} | 0 main.py | 16 ++-- vec_db/sqlite_vec_db_test.py | 2 +- 6 files changed, 182 insertions(+), 20 deletions(-) rename bots/{mmbot.py => mattermost_bot.py} (100%) diff --git a/Makefile b/Makefile index e725862..5cdceb0 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,6 @@ lint: @ruff check test: - coverage run --source=. --omit="**/*_test.py,bots/mmbot.py,bots/telegram_bot.py,main.py,test.py" -m pytest + coverage run --source=. --omit="**/*_test.py,bots/*_bot.py,main.py,test.py" -m pytest coverage report @coverage html diff --git a/bean_utils/bean.py b/bean_utils/bean.py index 3d37495..2a9245d 100644 --- a/bean_utils/bean.py +++ b/bean_utils/bean.py @@ -34,6 +34,22 @@ def __init__(self, fname=None) -> None: self._load() def _load(self): + """ + Load the beancount file and initialize the internal state. + + This function loads the beancount file specified by `self.fname` and + initializes the internal state. The internal state includes the entries + (transactions and metadata), options, accounts, modification times of + included files, and account files. + + This function does not return anything. It updates the following instance + variables: + - `_entries`: a list of parsed entries. + - `_options`: a dictionary of options. + - `_accounts`: a set of accounts. + - `mtimes`: a dictionary mapping filenames to modification times. + - `account_files`: a set of filenames. 
+ """ self._entries, errors, self._options = loader.load_file(self.fname) self._accounts = set() self.mtimes = {} @@ -51,7 +67,13 @@ def _load(self): self.mtimes[f] = Path(f).stat().st_mtime def _auto_reload(self, accounts_only=False): - # Check and reload + """ + Check and reload if any of the files have been modified. + + Args: + accounts_only (bool): If True, only check the files that contain account + transactions. Defaults to False. + """ files_to_check = self.mtimes.keys() if accounts_only: files_to_check = self.account_files @@ -76,13 +98,34 @@ def accounts(self): return self._accounts def find_account(self, account_str): + """ + Find an account that contains the given string. + + Args: + account_str (str): A substring to search for in the account string. + + Returns: + str or None: The account that contains the given substring, or None + if no such account is found. + """ for account in self.accounts: if account_str in account: return account return None def find_account_by_payee(self, payee): - # Find the account with the same payee + """ + Find the account with the same payee. + + Args: + payee (str): The payee to search for. + + Returns: + str or None: The account with the same payee, or None if not found. If the + transaction has multiple postings with missing units, the first one is + returned. If no expense account is found, the first expense account in + the transaction is returned. + """ target = None for trx in reversed(self._entries): if not isinstance(trx, Transaction): @@ -103,9 +146,27 @@ def find_account_by_payee(self, payee): return expense_account def run_query(self, q): + """ + A procedural interface to the `beancount.query` module. + """ return query.run_query(self.entries, self.options, q) - def match_new_args(self, args) -> List[List[str]]: + def modify_args_via_vec(self, args) -> List[List[str]]: + """ + Given a list of arguments, modify the arguments to match the transactions in the vector + database. 
+ + Args: + args (List[str]): The arguments to modify. + + Returns: + List[List[str]]: A list of modified arguments. + + This function queries the vector database to find transactions that match the given + arguments. It then rebuilds the narrations for each matching transaction and returns a + list of modified arguments. If no matching transactions are found, an empty list is + returned. + """ # Query from vector db matched_txs = query_txs(" ".join(args[1:])) candidate_args = [] @@ -120,6 +181,20 @@ def match_new_args(self, args) -> List[List[str]]: return candidate_args def build_trx(self, args): + """ + The core function of transaction generation. + + Args: + args (List[str]): A list of strings representing the transaction arguments. + The format is: {amount} {from_account} {to_account} {payee} {narration} [#{tag1} #{tag2} ...]. + The to_account and narration are optional. + + Returns: + str: The transaction string in the beancount format. + + Raises: + ValueError: If from_account or to_account is not found. 
+ """ amount, from_acc, to_acc, *extra = args amount = Decimal(amount) @@ -127,9 +202,11 @@ def build_trx(self, args): to_account = self.find_account(to_acc) payee = None + # from_account is required if from_account is None: err_msg = _("Account {acc} not found").format(acc=from_acc) raise ValueError(err_msg) + # Try to find the payee if to_account is not found if to_account is None: payee = to_acc to_account = self.find_account_by_payee(payee) @@ -139,7 +216,7 @@ def build_trx(self, args): if payee is None: payee, *extra = extra - kwargs = { + trx_info = { "date": datetime.now().astimezone().date(), "payee": payee, "from_account": from_account, @@ -154,14 +231,30 @@ def build_trx(self, args): for arg in extra: if arg.startswith(("#", "^")): tags.append(arg) - elif not kwargs["desc"]: - kwargs["desc"] = arg + elif not trx_info["desc"]: + trx_info["desc"] = arg if tags: - kwargs["tags"] = " " + " ".join(tags) + trx_info["tags"] = " " + " ".join(tags) - return transaction_tmpl.format(**kwargs) + return transaction_tmpl.format(**trx_info) def generate_trx(self, line) -> List[str]: + """ + The entry procedure for transaction generation. + + If the line cannot be directly converted into a transaction, + the function will attempt to match it from a vector database + or a RAG model. If all attempts fail, a ValueError will be raised. + + Args: + line (str): The line to generate a transaction from. + + Returns: + List[str]: A list of transactions generated from the line. + + Raises: + ValueError: If all attempts to generate a transaction fail. 
+ """ args = parse_args(line) try: return [self.build_trx(args)] @@ -177,7 +270,7 @@ def generate_trx(self, line) -> List[str]: if vec_enabled: # Query from vector db candidate_txs = [] - for new_args in self.match_new_args(args): + for new_args in self.modify_args_via_vec(args): with contextlib.suppress(ValueError): candidate_txs.append(self.build_trx(new_args)) if candidate_txs: @@ -187,6 +280,15 @@ def generate_trx(self, line) -> List[str]: raise e def clone_trx(self, text) -> str: + """ + Clone a transaction from text. + + Args: + text (str): Text containing one transaction. + + Returns: + str: A transaction with today's date. + """ entries, _, _ = parser.parse_string(text) try: txs = next(e for e in entries if isinstance(e, Transaction)) @@ -202,6 +304,15 @@ def clone_trx(self, text) -> str: return "\n".join(segments) def commit_trx(self, data): + """ + Commit a transaction to beancount file, and format. + + Args: + data (str): The transaction data in beancount format. + + Raises: + SubprocessError: If the bean-format command fails to execute. + """ fname = self.fname with open(fname, 'a') as f: f.write("\n" + data + "\n") diff --git a/bean_utils/vec_query.py b/bean_utils/vec_query.py index ebdc199..cb44010 100644 --- a/bean_utils/vec_query.py +++ b/bean_utils/vec_query.py @@ -26,6 +26,20 @@ def embedding(texts): def convert_account(account): + """ + Convert an account string to a specific segment. + + Args: + account (str): The account string to convert. + + Returns: + str: The converted account string. + + This function takes an account string and converts it to a specific segment + based on the configuration in conf.config.beancount.account_distinguation_range. + If the account string does not contain any segment, the original account string + is returned. 
+ """ dist_range = conf.config.beancount.account_distinguation_range segments = account.split(":") if isinstance(dist_range, int): @@ -44,7 +58,22 @@ def escape_quotes(s): def convert_to_natural_language(transaction) -> str: - # date = transaction.date.strftime('%Y-%m-%d') + """ + Convert a transaction object to a string representation of natural language for input to RAG. + + Args: + transaction (Transaction): A Transaction object. + + Returns: + str: The natural language representation of the transaction. + + The format of the representation is: + `"{payee}" "{description}" "{account1} {account2} ..." [{#tag1} {#tag2} ...]`, + where `{payee}` is the payee of the transaction, `{description}` is the narration, + and `{account1} {account2} ...` is a space-separated list of accounts in the transaction. + The accounts are converted to the most distinguishable level as specified in the configuration. + If the transaction has tags, they are appended to the end of the sentence. + """ payee = f'"{escape_quotes(transaction.payee)}"' description = f'"{escape_quotes(transaction.narration)}"' accounts = " ".join([convert_account(posting.account) for posting in transaction.postings]) @@ -56,6 +85,18 @@ def convert_to_natural_language(transaction) -> str: def build_tx_db(transactions): + """ + Build a transaction database from the given transactions. This function + consolidates the latest transactions and calculates their embeddings. + The embeddings are stored in a database for future use. + + Args: + transactions (list): A list of Transaction objects representing the + transactions. + + Returns: + int: The total number of tokens used for embedding. + """ _content_cache = {} def _read_lines(fname, start, end): if fname not in _content_cache: @@ -103,6 +144,16 @@ def _read_lines(fname, start, end): def query_txs(query): + """ + Query transactions based on the given query string. + + Args: + query (str): The query string to search for. 
+ + Returns: + list: A list of matched transactions. The length of the list is determined + by the `output_amount` configuration. + """ candidates = conf.config.embedding.candidates or 3 output_amount = conf.config.embedding.output_amount or 1 match = query_by_embedding(embedding([query])[0][0]["embedding"], query, candidates) diff --git a/bots/mmbot.py b/bots/mattermost_bot.py similarity index 100% rename from bots/mmbot.py rename to bots/mattermost_bot.py diff --git a/main.py b/main.py index c2993f8..c666441 100644 --- a/main.py +++ b/main.py @@ -13,27 +13,27 @@ def init_bot(config_path): bean.init_bean_manager() - -def main(): +def parse_args(): parser = argparse.ArgumentParser(prog='beanbot', description='Bot to translate text into beancount transaction') - subparser = parser.add_subparsers(title='sub command', dest='command') + subparser = parser.add_subparsers(title='sub command', required=True, dest='command') telegram_parser = subparser.add_parser("telegram") telegram_parser.add_argument('-c', nargs="?", type=str, default="config.yaml", help="config file path") mattermost_parser = subparser.add_parser("mattermost") mattermost_parser.add_argument('-c', nargs="?", type=str, default="config.yaml", help="config file path") - args = parser.parse_args() - if args.command is None: - parser.print_help() - return + return parser.parse_args() + + +def main(): + args = parse_args() init_bot(args.c) if args.command == "telegram": from bots.telegram_bot import run_bot elif args.command == "mattermost": - from bots.mmbot import run_bot + from bots.mattermost_bot import run_bot run_bot() diff --git a/vec_db/sqlite_vec_db_test.py b/vec_db/sqlite_vec_db_test.py index cfd75c1..c60cdf7 100644 --- a/vec_db/sqlite_vec_db_test.py +++ b/vec_db/sqlite_vec_db_test.py @@ -9,7 +9,7 @@ from vec_db import sqlite_vec_db -def test_sqlite_db(tmp_path, mock_config, monkeypatch): +def test_sqlite_db(tmp_path, mock_config, monkeypatch): monkeypatch.setattr(sqlite_vec_db, "_db", None) # Build 
DB txs = [