diff --git a/README.md b/README.md index da23fff..7c9df95 100644 --- a/README.md +++ b/README.md @@ -57,13 +57,19 @@ At high traffic levels, ingestion and query performance may be limited in a sing to larger aggregate numbers by using multiple processes or even distributing across multiple VMs. In this case, the result metrics will need to be aggregated to get the total QPS and throughput. +This can be supported by the --concurrency (-c) parameter which will spawn separate child processes: + i.e. Run 4 query benchmarks in separate processes with 4 tables each with table name prefix 'my-prefix' (bash): -```bash -./multi.sh 4 my-prefix "uv run bench.py -t 4 -q 10000 --no-ingest --no-index" -``` + +`uv run bench.py -t 4 -q 10000 -c 4 -p my-prefix --no-ingest --no-index` + This technique can also be used for high-throughput ingestion to multiple tables in parallel using a table prefix per process: i.e. Ingest to 50 tables in parallel across 10 processes with prefix 'high-throughput' (bash): -```bash -./multi.sh 10 high-throughput "uv run bench.py -t 5" -``` + +`uv run bench.py -c 10 -t 5 -p high-throughput` + + +### Reporting + +# todo diff --git a/bench.py b/bench.py index b7d4fc5..9b897cf 100755 --- a/bench.py +++ b/bench.py @@ -2,7 +2,10 @@ import concurrent import os import time +import traceback from concurrent.futures import wait +from multiprocessing import Pool +from pprint import pprint from typing import Iterable from lancedb.remote.errors import LanceDBClientError @@ -13,44 +16,71 @@ import pyarrow as pa from datasets import load_dataset, DownloadConfig -from src.cloud.benchmark.util import print_percentiles, await_indices +from cloud.benchmark.report import ( + set_result, + save_results, + get_report_dir, + generate_report, + aggregate_results, +) +from src.cloud.benchmark.util import print_percentiles, await_indices, get_percentiles def run_benchmark( - dataset: str, - num_tables: int, - batch_size: int, - num_queries: int, - ingest: bool, - index: bool, - prefix: str, - reset: bool + dataset: str, + num_tables: int, + batch_size: int, + num_queries: int, + ingest: bool, + index: bool, + prefix: str, + reset: bool, + test_run_id: str, ): - db = lancedb.connect( - uri=os.environ["LANCEDB_DB_URI"], - api_key=os.environ["LANCEDB_API_KEY"], - host_override=os.getenv("LANCEDB_HOST_OVERRIDE"), - region=os.getenv("LANCEDB_REGION", "us-east-1"), - ) + try: + print(f"starting test run {test_run_id} with prefix {prefix}") + db = lancedb.connect( + uri=os.environ["LANCEDB_DB_URI"], + api_key=os.environ["LANCEDB_API_KEY"], + host_override=os.getenv("LANCEDB_HOST_OVERRIDE"), + region=os.getenv("LANCEDB_REGION", "us-east-1"), + ) + results = { + "params": { + "test_run_id": test_run_id, + "dataset": dataset, + "batch_size": batch_size, + "num_queries": num_queries, + "prefix": prefix, + }, + "tables": {}, + } + + if reset: + _drop_tables(db, num_tables, prefix) + + if ingest: + tables = list(_create_tables(db, num_tables, prefix)) + _ingest(tables, dataset, batch_size, results) + else: + tables = list(_open_tables(db, num_tables, prefix)) - if reset: - _drop_tables(db, num_tables, prefix) + if index: + _create_indices(tables, results) - if ingest: - tables = list(_create_tables(db, num_tables, prefix)) - _ingest(tables, dataset, batch_size) - else: - tables = list(_open_tables(db, num_tables, prefix)) + _query_tables(tables, num_queries, results) - if index: - _create_indices(tables) + # pprint(results) + save_results(test_run_id, prefix, results) - _query_tables(tables, num_queries) - print("benchmark complete") + except Exception as e: + print(f"benchmark failed with error: {e}") + print(traceback.format_exc()) + raise def _create_tables( - db: lancedb.LanceDBConnection, num_tables: int, prefix: str + db: lancedb.LanceDBConnection, num_tables: int, prefix: str ) -> Iterable[RemoteTable]: schema = pa.schema( [ @@ -77,7 +107,7 @@ def _create_tables( def _open_tables( - db: lancedb.LanceDBConnection, num_tables: int, prefix: str + db: lancedb.LanceDBConnection, num_tables: int, prefix: str ) -> Iterable[RemoteTable]: for i in range(num_tables): table_name = f"{prefix}-{i}" @@ -91,24 +121,32 @@ def _drop_tables(db, num_tables, prefix): db.drop_table(t.name) -def _ingest(tables: list[RemoteTable], dataset: str, batch_size: int): +def _ingest(tables: list[RemoteTable], dataset: str, batch_size: int, results: dict): start = time.time() with concurrent.futures.ThreadPoolExecutor(max_workers=len(tables)) as executor: futures = [] for table in tables: - futures.append(executor.submit(_ingest_table, dataset, table, batch_size)) - results = [future.result() for future in futures] + futures.append( + executor.submit(_ingest_table, dataset, table, batch_size, results) + ) + r = [future.result() for future in futures] total_s = time.time() - start - total_rows = sum(results) + total_rows = sum(r) + rows_s = total_rows / total_s print( - f"ingested {total_rows} rows in {len(tables)} tables in {total_s:.1f}s. average: {total_rows / total_s:.1f}rows/s" + f"ingested {total_rows} rows in {len(tables)} tables in {total_s:.1f}s. average: {rows_s :.1f}rows/s" ) + set_result(results, "ingest_time_all_tables", total_s) + set_result(results, table.name, "ingest_total_rows_all_tables", total_rows) + set_result(results, table.name, "ingest_avg_rows_s_all_tables", rows_s) -def _ingest_table(dataset: str, table: RemoteTable, batch_size: int) -> int: - # todo: support batch size > 1000 + +def _ingest_table( + dataset: str, table: RemoteTable, batch_size: int, results: dict +) -> int: add_times = [] begin = time.time() total_rows = 0 @@ -124,10 +162,18 @@ def _ingest_table(dataset: str, table: RemoteTable, batch_size: int) -> int: ) total_s = int((time.time() - begin)) + rows_s = total_rows / total_s print( - f"{table.name}: ingested {total_rows} rows in {total_s}s. average: {total_rows / total_s:.1f}rows/s" + f"{table.name}: ingested {total_rows} rows in {total_s}s. average: {rows_s :.1f}rows/s" ) - print_percentiles(add_times) + add_percentiles = get_percentiles(add_times) + print_percentiles(add_percentiles) + + set_result(results, table.name, "ingest_time_s", total_s) + set_result(results, table.name, "rows_ingested", total_rows) + set_result(results, table.name, "ingest_avg_rows_s", rows_s) + set_result(results, table.name, "ingest_percentiles_ms", add_percentiles) + return total_rows @@ -144,20 +190,22 @@ def _split_record_batch(record_batch, batch_size): yield record_batch.slice(i, min(batch_size, num_rows - i)) -def _query_tables(tables: list[RemoteTable], num_queries: int): +def _query_tables(tables: list[RemoteTable], num_queries: int, results: dict): num_tables = len(tables) with concurrent.futures.ThreadPoolExecutor(max_workers=num_tables) as executor: futures = [] for table in tables: - futures.append(executor.submit(_query_table, table, num_queries)) - results = [future.result() for future in futures] + futures.append(executor.submit(_query_table, table, num_queries, results)) + r = [future.result() for future in futures] total_queries = num_queries * num_tables - total_qps = sum(results) + total_qps = sum(r) print( f"completed {total_queries} queries on {num_tables} tables. average: {total_qps:.1f}QPS" ) + set_result(results, None, "total_queries", total_queries) + set_result(results, None, "total_qps", total_qps) def _await_index(table: RemoteTable, index_type: str, start_time): @@ -167,7 +215,7 @@ def _await_index(table: RemoteTable, index_type: str, start_time): ) -def _create_indices(tables: list[RemoteTable]): +def _create_indices(tables: list[RemoteTable], results): # create the indices - these will be created async table_indices = {} for t in tables: @@ -231,7 +279,7 @@ def _convert_dataset(schema, dataset: str, batch_size: int) -> Iterable[pa.Recor yield b -def _query_table(table: RemoteTable, num_queries: int, warmup_queries=100): +def _query_table(table: RemoteTable, num_queries: int, results: dict, warmup_queries=1): # log a warning if data is not fully indexed total_rows = table.count_rows() for idx in table.list_indices()["indexes"]: @@ -258,13 +306,22 @@ def _query_table(table: RemoteTable, num_queries: int, warmup_queries=100): total_s = int(time.time() - begin) qps = num_queries / total_s print(f"{table.name}: query count: {num_queries} average: {qps :.1f}QPS") - print_percentiles(diffs) + percentiles = get_percentiles(diffs) + print_percentiles(percentiles) + + set_result(results, table.name, "query_total_time_s", total_s) + set_result(results, table.name, "num_queries", num_queries) + set_result(results, table.name, "query_avg_qps", qps) + set_result(results, table.name, "query_percentiles_ms", percentiles) + return qps def _query(table: RemoteTable, nprobes=1): try: - table.search(np.random.standard_normal(1536)).metric("cosine").nprobes(nprobes).select(["openai", "title"]).to_list() + table.search(np.random.standard_normal(1536)).metric("cosine").nprobes( + nprobes + ).select(["openai", "title"]).to_list() except Exception as e: print(f"{table.name}: error during query: {e}") @@ -320,6 +377,13 @@ def main(): default="ldb-cloud-benchmarks", help="table name prefix", ) + parser.add_argument( + "-c", + "--concurrency", + type=int, + default=1, + help="number of parallel processes to launch", + ) parser.add_argument( "-r", "--reset", @@ -328,18 +392,46 @@ def main(): action=argparse.BooleanOptionalAction, help="drop tables before ingesting", ) + parser.add_argument( + "-i", + "--id", + type=str, + default=None, + help="drop tables before ingesting", + ) args = parser.parse_args() print(args) - run_benchmark( - args.dataset, - args.tables, - args.batch, - args.queries, - args.ingest, - args.index, - args.prefix, - args.reset - ) + + test_run_id = args.id + if not test_run_id: + test_run_id = str(int(time.time())) + + # launch child processes based on configured concurrency + # each child process will operate on unique tables based on table name prefix + with Pool(processes=args.concurrency) as pool: + p_args = [ + ( + args.dataset, + args.tables, + args.batch, + args.queries, + args.ingest, + args.index, + f"{args.prefix}-{i+1}" if args.concurrency > 1 else args.prefix, + args.reset, + test_run_id, + ) + for i in range(args.concurrency) + ] + pool.starmap(run_benchmark, p_args) + + # aggregate results from concurrent benchmarks + aggregate_results(test_run_id) + + # generate the report + report_file = generate_report(get_report_dir(test_run_id)) + print(f"finished test run {test_run_id}. ") + print(f"report saved to {report_file}") if __name__ == "__main__": diff --git a/multi.sh b/multi.sh deleted file mode 100755 index a73a9bb..0000000 --- a/multi.sh +++ /dev/null @@ -1,9 +0,0 @@ -n=$1 -prefix=$2 -cmd=$3 -rm -f out.txt -for i in $(seq 1 $n); do - $cmd -p "$prefix-$i" & >> out.txt -done - -tail -f out.txt \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b64d7ec..99897b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,8 @@ dependencies = [ "lancedb", "tqdm", "pyarrow", - "backoff" + "backoff", + "jinja2" ] [tool.uv] diff --git a/src/cloud/benchmark/report.py b/src/cloud/benchmark/report.py new file mode 100644 index 0000000..c5c8262 --- /dev/null +++ b/src/cloud/benchmark/report.py @@ -0,0 +1,99 @@ +import json +import os +from pathlib import Path +from typing import Optional + +from jinja2 import Template + + +def generate_report(test_run_id: str) -> str: + report_dir = get_report_dir(test_run_id) + + # copy template to report dir + current_dir = os.path.dirname(os.path.abspath(__file__)) + template = os.path.join(current_dir, "report/index.html") + + new_path = report_dir / "index.html" + + # load aggregated data from report dir + with open(report_dir / "aggregated.json", "r") as f: + aggregated = json.load(f) + + with open(template, "r+") as t: + template = Template(t.read()) + html = template.render(aggregated) + with open(new_path, "w+") as o: + o.write(html) + + return str(new_path) + + +def set_result(results: dict, table_name: Optional[str], key: str, val): + if not table_name: + results[key] = val + return + if table_name not in results: + results["tables"][table_name] = {} + results["tables"][table_name][key] = val + + +def save_results(test_run_id: str, prefix: str, results: dict): + result_dir = get_report_dir(test_run_id) + Path(result_dir).mkdir(parents=True, exist_ok=True) + result_file = result_dir / f"{prefix}.json" + with open(result_file, "w") as json_file: + json.dump(results, json_file, indent=4) + print(f"results saved to {result_file}") + + +def get_report_dir( + test_run_id: str, base_path: str = "/tmp/lancedb-cloud-benchmarks/results" +) -> Path: + result_dir = Path(base_path) / test_run_id + return result_dir + + +def aggregate_results(test_run_id: str, out_file_name: str = "aggregated.json") -> str: + """Aggregate the results from all processes and store it in the output file""" + report_dir = get_report_dir(test_run_id) + out_file = os.path.join(report_dir, out_file_name) + merged = { + "processes": {}, + } + + total_qps = 0 + total_queries = 0 + total_rows = 0 + total_rows_s = 0 + params = None + for f in os.listdir(report_dir): + if f.endswith(".json"): + file_path = os.path.join(report_dir, f) + with open(file_path, "r") as f: + try: + data = json.load(f) + merged["processes"][data["params"]["prefix"]] = data + params = data["params"] + + # aggregate the results for all processes and tables + total_qps += data.get("total_qps", 0) + total_queries += data.get("total_queries", 0) + total_rows += data.get("ingest_total_rows_all_tables", 0) + total_rows_s += data.get("ingest_avg_rows_s_all_tables", 0) + + except json.JSONDecodeError as e: + print(f"error reading {f}: {e}") + + merged["params"] = params + merged["aggregated"] = { + "qps": total_qps, + "queries": total_queries, + "ingestion_rows_s": total_rows_s, + "ingested_rows": total_rows, + } + + with open(out_file, "w") as outfile: + json.dump(merged, outfile, indent=4) + + print(f"aggregated results saved to {out_file}") + return str(out_file) diff --git a/src/cloud/benchmark/report/index.html b/src/cloud/benchmark/report/index.html new file mode 100644 index 0000000..124b69e --- /dev/null +++ b/src/cloud/benchmark/report/index.html @@ -0,0 +1,160 @@ + + + +
+ + + + + + + + + + + + + + + + + + + +
+ Rows{{aggregated.ingested_rows}}
+ Ingestion rows/s{{aggregated.ingestion_rows_s}}
+
+ Queries{{aggregated.queries}}
+ Queries Per Second{{aggregated.qps}}
+
Table | +Ingestion Percentiles | +Query Percentiles | +
---|---|---|
{{name}} | +{{table.ingest_percentiles_ms}} | +{{table.query_percentiles_ms}} | +