Skip to content

Commit

Permalink
feat: row_count and unique columns (#14)
Browse files Browse the repository at this point in the history
Signed-off-by: cutecutecat <[email protected]>
  • Loading branch information
cutecutecat authored Oct 8, 2024
1 parent 5e175a9 commit b5c22d8
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 25 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,13 @@ from pgvecto_rs.sdk import PGVectoRs, Record
# Create a client
client = PGVectoRs(
db_url="postgresql+psycopg://postgres:mysecretpassword@localhost:5432/postgres",
table_name="example",
collection_name="example",
dimension=3,
)

try:
# Add some records
client.add_records(
client.insert(
[
Record.from_text("hello 1", [1, 2, 3]),
Record.from_text("hello 2", [1, 2, 4]),
Expand Down
3 changes: 3 additions & 0 deletions examples/sdk_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def embed(text: str):
client.insert(records1)
client.insert(records2)

# Count rows
client.row_count(estimate=True)

# Query (With a filter from the filters module)
print("#################### First Query ####################")
for record, dis in client.search(
Expand Down
7 changes: 7 additions & 0 deletions src/pgvecto_rs/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ def __init__(self, vtype: str) -> None:
class TextParseError(PGVectoRsError):
def __init__(self, payload: str, dtype: type) -> None:
super().__init__(f"failed to parse text of '{payload}' as a {dtype}")


class CountRowsEstimateCondError(PGVectoRsError):
def __init__(self) -> None:
super().__init__(
"cannot use estimate=True and a condition for row count requests"
)
99 changes: 81 additions & 18 deletions src/pgvecto_rs/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,95 @@
from uuid import UUID

from numpy import ndarray
from sqlalchemy import ColumnElement, Float, create_engine, delete, insert, select, text
from sqlalchemy import (
BIGINT,
Column,
ColumnElement,
Float,
create_engine,
delete,
func,
insert,
select,
text,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql.pg_catalog import pg_class
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm.session import Session
from sqlalchemy.types import String

from pgvecto_rs.errors import CountRowsEstimateCondError
from pgvecto_rs.sdk.filters import Filter
from pgvecto_rs.sdk.record import Record, RecordORM, RecordORMType
from pgvecto_rs.sdk.record import Record, RecordORM, RecordORMType, Unique
from pgvecto_rs.sqlalchemy import VECTOR


def table_factory(collection_name, dimension, table_args, base_class=RecordORM):
def __init__(self, **kwargs): # noqa: N807
base_class.__init__(self, **kwargs)

newclass = type(
collection_name,
(base_class,),
{
"__init__": __init__,
"__tablename__": f"collection_{collection_name}",
"__table_args__": table_args,
"id": mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
),
"text": mapped_column(String),
"meta": mapped_column(postgresql.JSONB),
"embedding": mapped_column(VECTOR(dimension)),
},
)
return newclass


class PGVectoRs:
_engine: Engine
_table: Type[RecordORM]
dimension: int

def __init__(
self, db_url: str, collection_name: str, dimension: int, recreate: bool = False
def __init__( # noqa: PLR0913
self,
db_url: str,
collection_name: str,
dimension: int,
recreate: bool = False,
constraints: Union[List[Unique], None] = None,
) -> None:
"""Connect to an existing table or create a new empty one.
If the `recreate=True`, the table will be dropped if it exists.
Args:
----
db_url (str): url to the database.
table_name (str): name of the table.
collection_name (str): name of the collection. A prefix `collection_` is added to actual table name.
dimension (int): dimension of the embeddings.
recreate (bool): drop the table if it exists. Defaults to False.
constraints (List[Unique]): add constraints to columns, e.g. UNIQUE constraint
"""

class _Table(RecordORM):
__tablename__ = f"collection_{collection_name}"
__table_args__ = {"extend_existing": True} # noqa: RUF012
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
if constraints is None or len(constraints) == 0:
table_args = {"extend_existing": True}
else:
table_args = (
*[col.make() for col in constraints],
{"extend_existing": True},
)
text: Mapped[str] = mapped_column(String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
embedding: Mapped[ndarray] = mapped_column(VECTOR(dimension))

self._engine = create_engine(db_url)
self._table = table_factory(collection_name, dimension, table_args)
with Session(self._engine) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
if recreate:
session.execute(text(f"DROP TABLE IF EXISTS {_Table.__tablename__}"))
session.execute(
text(f"DROP TABLE IF EXISTS {self._table.__tablename__}")
)
session.commit()
self._table = _Table
self._table.__table__.create(self._engine, checkfirst=True)
self.dimension = dimension

Expand Down Expand Up @@ -105,6 +145,29 @@ def search(
res = session.execute(stmt)
return [(Record.from_orm(row[0]), row[1]) for row in res]

# ================ Stat ==================
def row_count(self, estimate: bool = True, filter: Optional[Filter] = None) -> int:
if estimate and filter is not None:
raise CountRowsEstimateCondError()
if estimate:
stmt = (
select(func.cast(Column("reltuples", Float), BIGINT).label("rows"))
.select_from(pg_class)
.where(
Column("oid", Float)
== func.cast(self._table.__tablename__, postgresql.REGCLASS)
)
)
with Session(self._engine) as session:
result = session.execute(stmt).fetchone()
else:
stmt = select(func.count("*").label("rows")).select_from(self._table)
if filter is not None:
stmt = stmt.where(filter(self._table))
with Session(self._engine) as session:
result = session.execute(stmt).fetchone()
return result[0]

# ================ Delete ================
def delete(self, filter: Filter) -> None:
with Session(self._engine) as session:
Expand Down
24 changes: 24 additions & 0 deletions src/pgvecto_rs/sdk/record.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
from enum import IntEnum
from functools import reduce
from typing import List, Optional, Type, Union
from uuid import UUID, uuid4

from numpy import array, float32, ndarray
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import DeclarativeBase, Mapped


class Column(IntEnum):
TEXT = 1
META = 2
EMBEDDING = 4


class Unique:
def __init__(self, columns: List[Column]):
self.value = reduce(lambda x, y: x | y, columns)

def make(self) -> UniqueConstraint:
ans: List[UniqueConstraint] = []
if self.value & Column.TEXT:
ans.append("text")
if self.value & Column.META:
ans.append("meta")
if self.value & Column.EMBEDDING:
ans.append("embedding")
return UniqueConstraint(*ans)


class RecordORM(DeclarativeBase):
__tablename__: str
id: Mapped[UUID]
Expand Down
6 changes: 3 additions & 3 deletions src/pgvecto_rs/types/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def __init__(
self.ratio = ratio

def dump(self) -> dict:
if self.type == "trivial":
return {"quantization": {"trivial": {}}}
if self.type == "product":
return {"quantization": {"product": {"ratio": self.ratio}}}
elif self.type == "scalar":
return {"quantization": {"scalar": {}}}
else:
return {"quantization": {"product": {"ratio": self.ratio}}}
return {}


class Flat:
Expand Down
8 changes: 6 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
DB_NAME = os.getenv("DB_NAME", "postgres")

# Run tests with shell:
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest tests/
URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
DATABASES = {
"default": {
Expand Down Expand Up @@ -106,7 +106,11 @@
),
(
IndexOption(index=Ivf(quantization=Quantization(typ="trivial"))),
"[indexing.ivf.quantization.trivial]\n",
"[indexing.ivf]\n",
),
(
IndexOption(index=Ivf()),
"[indexing.ivf]\n",
),
(
IndexOption(
Expand Down
112 changes: 112 additions & 0 deletions tests/test_sdk.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import time
from typing import Callable, List

import numpy as np
import pytest
from sqlalchemy.exc import IntegrityError

from pgvecto_rs.sdk import Filter, PGVectoRs, Record, filters
from pgvecto_rs.sdk.record import Column, Unique
from tests import (
COSINE_DIS_OP,
L2_DIS_OP,
Expand Down Expand Up @@ -82,3 +85,112 @@ def test_search_order_and_limit(
for rec, dis in client.search(dis_oprand, dis_op, top_k=4):
expect = assert_func(dis_oprand, rec.embedding.to_numpy())
assert np.allclose(expect, dis, atol=1e-10)


def test_unique_text_table(
client: PGVectoRs,
):
unique_client = PGVectoRs(
db_url=URL,
collection_name="unique_text",
dimension=3,
recreate=True,
constraints=[Unique(columns=[Column.TEXT])],
)
it = iter(MockTexts.items())
text1, vector1 = next(it)
_, vector2 = next(it)
records_ok = [Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()]
records_fail = [
Record.from_text(text1, vector1, {"src": "src1"}),
Record.from_text(text1, vector2, {"src": "src2"}),
]
unique_client.insert(records_ok)
unique_client.delete_all()
with pytest.raises(IntegrityError):
unique_client.insert(records_fail)


def test_unique_meta_table(
client: PGVectoRs,
):
unique_client = PGVectoRs(
db_url=URL,
collection_name="unique_meta",
dimension=3,
recreate=True,
constraints=[Unique(columns=[Column.META])],
)
it = iter(MockTexts.items())
text1, vector1 = next(it)
text2, vector2 = next(it)
records_ok = [
Record.from_text(text1, vector1, {"src": "src1"}),
Record.from_text(text2, vector2, {"src": "src2"}),
]
records_fail = [
Record.from_text(text1, vector1, {"src": "src1"}),
Record.from_text(text2, vector2, {"src": "src1"}),
]
unique_client.insert(records_ok)
unique_client.delete_all()
with pytest.raises(IntegrityError):
unique_client.insert(records_fail)


def test_unique_text_meta_table(
client: PGVectoRs,
):
unique_client = PGVectoRs(
db_url=URL,
collection_name="unique_both",
dimension=3,
recreate=True,
constraints=[Unique(columns=[Column.TEXT, Column.META])],
)
it = iter(MockTexts.items())
text1, vector1 = next(it)
text2, vector2 = next(it)
records_ok = [
Record.from_text(text1, vector1, {"src": "src1"}),
Record.from_text(text2, vector2, {"src": "src1"}),
]
records_fail = [
Record.from_text(text1, vector1, {"src": "src1"}),
Record.from_text(text1, vector2, {"src": "src1"}),
]
unique_client.insert(records_ok)
unique_client.delete_all()
with pytest.raises(IntegrityError):
unique_client.insert(records_fail)


COUNT = 1000


def test_count_table(
client: PGVectoRs,
):
count_client = PGVectoRs(
db_url=URL,
collection_name="count",
dimension=3,
recreate=True,
)
it = iter(MockTexts.items())
text1, vector1 = next(it)
records = [Record.from_text(text1, vector1, {"src": "src1"}) for _ in range(COUNT)]
count_client.insert(records)

rows = count_client.row_count(estimate=False)
assert rows == COUNT

rows = count_client.row_count(estimate=False, filter=filter_src2)
assert rows == 0

for _ in range(90):
estimate_rows = count_client.row_count(estimate=True)
if estimate_rows == COUNT:
return
time.sleep(1)
raise AssertionError

0 comments on commit b5c22d8

Please sign in to comment.