forked from timescale/pgai
chore: add vectorizer load test scripts
Showing 5 changed files with 225 additions and 0 deletions.
@@ -0,0 +1,20 @@
# Vectorizer Load Test

This directory contains scripts to help with load testing the vectorizer. The
scripts create a table named `wiki` with approximately 1.5M rows to be
vectorized.

1. Add a `.env` file containing a `DB_URL` entry. The value should be a
   Postgres connection URL, e.g. `DB_URL=postgres://user:password@localhost:5432/postgres`.
   It can point to a local or a remote database.
2. Run `./load.sh`. This script will:
   1. Download a dataset from HuggingFace
   2. Load it into a working table named `wiki_orig`
   3. Process the data into the `wiki` table. The original data is already
      chunked, so we dechunk it by concatenating each article's paragraphs
      (a quick sanity check for the result is shown after this list).
   4. [optionally] Drop the working tables
   5. [optionally] Dump the `wiki` table to `wiki.dump`

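Once `load.sh` finishes, a quick sanity check along these lines confirms the
load worked. The expected counts assume you loaded the full dataset and have
not yet dropped the working tables; adjust accordingly if you loaded a subset.

```sql
-- rough post-load sanity check (full-dataset numbers)
select count(*) from wiki;       -- roughly 1.5M dechunked articles
select count(*) from wiki_orig;  -- roughly 8.59M original paragraph rows
select count(*) from queue;      -- 0 once dechunking has finished
```
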
If you already have a `wiki.dump` file, you can use `./restore.sh` to recreate
the `wiki` table without having to go through the process above. This is much
faster.

Once you have created the `wiki` table, you are ready to create one or more
vectorizers on the table. Happy testing!

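For reference, creating a vectorizer on `wiki` looks roughly like the sketch
below. Treat it as illustrative only: the exact `ai.create_vectorizer`
arguments depend on your pgai version and embedding provider, and the
destination name, model, and dimensions shown here are arbitrary choices, not
something these scripts set up.

```sql
-- illustrative sketch; check the pgai vectorizer docs for the exact
-- signature in your version before running
select ai.create_vectorizer(
    'wiki'::regclass,
    destination => 'wiki_embeddings',  -- assumed destination name
    embedding => ai.embedding_openai('text-embedding-3-small', 768),  -- assumed model/dimensions
    chunking => ai.chunking_recursive_character_text_splitter('body')
);
```
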
@@ -0,0 +1,178 @@
import os
import shutil
import subprocess
from multiprocessing import Process
from pathlib import Path

import psycopg
from datasets import load_dataset
from dotenv import load_dotenv

load_dotenv()
DB_URL = os.environ["DB_URL"]


def load():
    print("fetching dataset...")
    data = load_dataset("Cohere/wikipedia-22-12", 'en', split='train', streaming=True, trust_remote_code=True)

    to_load = -1  # -1 means load the entire dataset
    ans = 'q'
    while ans.lower() not in {'y', 'n'}:
        ans = input("do you want to load the entire dataset (8.59M rows)? (y/n) ")
        if ans.lower() == 'n':
            to_load = input("how many rows do you want to load? ")
            try:
                to_load = int(to_load)
            except ValueError:
                ans = 'q'  # not an integer: reset the answer so the prompt is asked again
                print("invalid input")
                continue
        elif ans.lower() == 'y':
            to_load = -1

    def batches():
        # group the streamed rows into batches of 1000 for COPY; rows are
        # numbered from 1 so a limit of N loads exactly N rows
        batch = []
        for i, row in enumerate(data, start=1):
            batch.append((i, row))
            if len(batch) == 1000 or i == to_load:
                yield batch
                batch = []
                if i == to_load:
                    break
        if len(batch) > 0:
            yield batch

print("connecting to the database...") | ||
with psycopg.connect(DB_URL) as con: | ||
with con.cursor() as cur: | ||
print("creating wiki_orig table...") | ||
cur.execute(""" | ||
drop table if exists wiki_orig; | ||
create table wiki_orig | ||
( title text | ||
, "text" text | ||
, url text | ||
, wiki_id int | ||
, paragraph_id int | ||
) | ||
""") | ||
con.commit() | ||
|
||
print("loading data...") | ||
for batch in batches(): | ||
with con.cursor(binary=True) as cur: | ||
with cur.copy(""" | ||
copy wiki_orig (title, "text", url, wiki_id, paragraph_id) | ||
from stdin (format binary) | ||
""") as cpy: | ||
cpy.set_types(['text', 'text', 'text', 'integer', 'integer']) | ||
for i, row in batch: | ||
cpy.write_row((row["title"], row["text"], row["url"], row["wiki_id"], row["paragraph_id"])) | ||
if i != 0 and (i % 1000 == 0 or i == to_load): | ||
print(f"{i}") | ||
con.commit() | ||
|
||
        with con.cursor() as cur:
            print("creating index on wiki_orig...")
            cur.execute("create index on wiki_orig (wiki_id, paragraph_id)")
            print("creating wiki table...")
            cur.execute("""
                drop table if exists wiki;
                create table wiki
                ( id bigint not null primary key generated by default as identity
                , title text not null
                , body text not null
                , wiki_id int not null
                , url text not null
                )
            """)
            print("creating queue table...")
            # the queue holds one row per article id; dechunk workers claim
            # rows from it with FOR UPDATE SKIP LOCKED
            cur.execute("drop table if exists queue")
            cur.execute("create table queue (wiki_id int not null primary key)")
            cur.execute("insert into queue (wiki_id) select distinct wiki_id from wiki_orig")


def dechunk():
    # worker process: claim one article id at a time from the queue
    # (FOR UPDATE SKIP LOCKED so concurrent workers don't collide), glue its
    # paragraphs back together in order, and commit after each article
    load_dotenv()  # ensure DB_URL is available in the child process
    with psycopg.connect(os.environ["DB_URL"], autocommit=True) as con:
        with con.cursor() as cur:
            cur.execute("""
                do language plpgsql $$
                declare
                    _wiki_id int;
                begin
                    loop
                        select wiki_id into _wiki_id
                        from queue
                        for update skip locked
                        limit 1
                        ;
                        exit when not found;
                        insert into wiki (title, body, wiki_id, url)
                        select
                          title
                        , string_agg("text", E'\n\n' order by paragraph_id)
                        , wiki_id
                        , url
                        from wiki_orig
                        where wiki_id = _wiki_id
                        group by wiki_id, title, url
                        ;
                        delete from queue
                        where wiki_id = _wiki_id
                        ;
                        commit;
                    end loop;
                end
                $$;
            """)


if __name__ == '__main__':
    load()

    concurrency = 0
    while concurrency < 1:
        concurrency = input("how many processes do you want to use to dechunk? ")
        try:
            concurrency = int(concurrency)
        except ValueError:
            concurrency = 0
            print("invalid input")
            continue

    print("dechunking...")
    procs = []
    for _ in range(concurrency):
        proc = Process(target=dechunk)
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()

if input("do you want to drop the intermediate tables? (y/n) ").lower() == 'y': | ||
with psycopg.connect(DB_URL) as con: | ||
with con.cursor() as cur: | ||
print("dropping wiki_orig...") | ||
cur.execute("drop table if exists wiki_orig") | ||
print("dropping queue...") | ||
cur.execute("drop table if exists queue") | ||
|
||
    if shutil.which("pg_dump") is not None:
        if input("do you want to dump the dataset? (y/n) ").lower() == "y":
            p = Path.cwd().joinpath("wiki.dump")
            if p.is_file():
                p.unlink(missing_ok=True)
            print("dumping dataset to wiki.dump...")
            subprocess.run(f"""pg_dump -d "{DB_URL}" -Fc -v -f wiki.dump --no-owner --no-privileges --table=public.wiki""",
                           check=True,
                           shell=True,
                           env=os.environ,
                           cwd=str(Path.cwd()),
                           )

    print("done")
@@ -0,0 +1,12 @@
#!/usr/bin/env bash

# create a virtualenv with the loader's dependencies on first run, reuse it afterwards
if [ -d '.venv' ]; then
  source .venv/bin/activate
else
  python3 -m venv .venv
  source .venv/bin/activate
  pip install --upgrade pip
  pip install python-dotenv datasets "psycopg[binary]"
fi

python3 load.py
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

# if you have already generated a wiki.dump file using load.sh, you can skip
# load.sh in the future and run this script to restore from the dump

if [ ! -f "wiki.dump" ]; then
  echo "wiki.dump does not exist"
  exit 1
fi

if [ -f '.env' ]; then
  set -a && . ".env" && set +a # load the env vars from the .env file
fi

pg_restore -d "$DB_URL" -v -Fc --exit-on-error --no-owner --no-privileges wiki.dump