Skip to content

Commit

Permalink
Add flexibility to Store trait + better pgvector filtering (#273)
Browse files Browse the repository at this point in the history
* Fix method name typo

* Abstract over Vector Store filters, Add more robust pgvector filtering
  • Loading branch information
dredozubov authored Jan 23, 2025
1 parent a1e29f6 commit 60c7125
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 68 deletions.
6 changes: 4 additions & 2 deletions src/vectorstore/opensearch/opensearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ impl Store {

#[async_trait]
impl VectorStore for Store {
type Options = VecStoreOptions<Value>;

async fn add_documents(
&self,
docs: &[Document],
opt: &VecStoreOptions,
opt: &Self::Options,
) -> Result<Vec<String>, Box<dyn Error>> {
let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();
let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
Expand Down Expand Up @@ -154,7 +156,7 @@ impl VectorStore for Store {
&self,
query: &str,
limit: usize,
opt: &VecStoreOptions,
opt: &Self::Options,
) -> Result<Vec<Document>, Box<dyn Error>> {
let query_vector = self.embedder.embed_query(query).await?;
let query = build_similarity_search_query(
Expand Down
10 changes: 5 additions & 5 deletions src/vectorstore/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ use crate::embedding::embedder_trait::Embedder;
/// .with_filters(json!({"genre": "Sci-Fi"}))
/// .with_embedder(my_embedder);
/// ```
pub struct VecStoreOptions {
pub struct VecStoreOptions<F> {
pub name_space: Option<String>,
pub score_threshold: Option<f32>,
pub filters: Option<Value>,
pub filters: Option<F>,
pub embedder: Option<Arc<dyn Embedder>>,
}

impl Default for VecStoreOptions {
impl Default for VecStoreOptions<Value> {
fn default() -> Self {
Self::new()
}
}

impl VecStoreOptions {
impl<F> VecStoreOptions<F> {
pub fn new() -> Self {
VecStoreOptions {
name_space: None,
Expand All @@ -49,7 +49,7 @@ impl VecStoreOptions {
self
}

pub fn with_filters(mut self, filters: Value) -> Self {
pub fn with_filters(mut self, filters: F) -> Self {
self.filters = Some(filters);
self
}
Expand Down
15 changes: 8 additions & 7 deletions src/vectorstore/pgvector/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ use sqlx::{postgres::PgPoolOptions, Pool, Postgres, Row, Transaction};
use crate::{embedding::embedder_trait::Embedder, vectorstore::VecStoreOptions};

use super::{
HNSWIndex, Store, PG_LOCKID_EXTENSION, PG_LOCK_ID_COLLECTION_TABLE, PG_LOCK_ID_EMBEDDING_TABLE,
HNSWIndex, PgFilter, PgOptions, Store, PG_LOCKID_EXTENSION, PG_LOCK_ID_COLLECTION_TABLE,
PG_LOCK_ID_EMBEDDING_TABLE,
};

const DEFAULT_COLLECTION_NAME: &str = "langchain";
const DEFAULT_PRE_DELETE_COLLECTION: bool = false;
const DEFAULT_EMBEDDING_STORE_TABLE_NAME: &str = "langchain_pg_embedding";
const DEFAULT_COLLECTION_STORE_TABLE_NAME: &str = "langchain_pg_collection";

pub struct StoreBuilder {
pub struct StoreBuilder<F> {
pool: Option<Pool<Postgres>>,
embedder: Option<Arc<dyn Embedder>>,
connection_url: Option<String>,
Expand All @@ -25,11 +26,11 @@ pub struct StoreBuilder {
collection_uuid: String,
collection_table_name: String,
collection_metadata: HashMap<String, Value>,
vstore_options: VecStoreOptions,
vstore_options: VecStoreOptions<F>,
hns_index: Option<HNSWIndex>,
}

impl StoreBuilder {
impl StoreBuilder<PgFilter> {
// Returns a new StoreBuilder instance with default values for each option
pub fn new() -> Self {
StoreBuilder {
Expand All @@ -43,7 +44,7 @@ impl StoreBuilder {
collection_name: DEFAULT_COLLECTION_NAME.into(),
collection_table_name: DEFAULT_COLLECTION_STORE_TABLE_NAME.into(),
collection_metadata: HashMap::new(),
vstore_options: VecStoreOptions::default(),
vstore_options: VecStoreOptions::new(),
hns_index: None,
}
}
Expand Down Expand Up @@ -88,12 +89,12 @@ impl StoreBuilder {
self
}

pub fn vstore_options(mut self, vstore_options: VecStoreOptions) -> Self {
pub fn vstore_options(mut self, vstore_options: PgOptions) -> Self {
self.vstore_options = vstore_options;
self
}

fn collecion_metadata(mut self, collecion_metadata: HashMap<String, Value>) -> Self {
fn collection_metadata(mut self, collecion_metadata: HashMap<String, Value>) -> Self {
self.collection_metadata = collecion_metadata;
self
}
Expand Down
125 changes: 93 additions & 32 deletions src/vectorstore/pgvector/pgvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,74 @@ pub struct Store {
pub(crate) pre_delete_collection: bool,
pub(crate) vector_dimensions: i32,
pub(crate) hns_index: Option<HNSWIndex>,
pub(crate) vstore_options: VecStoreOptions,
pub(crate) vstore_options: PgOptions,
}

#[derive(Debug, Clone, PartialEq)]
pub enum PgFilter {
Eq(PgLit, PgLit),
Cmp(std::cmp::Ordering, PgLit, PgLit),
In(PgLit, Vec<String>),
And(Vec<PgFilter>),
Or(Vec<PgFilter>),
}

pub type Column = String;

pub type Path = Vec<String>;

#[derive(Debug, Clone, PartialEq)]
pub enum PgLit {
JsonField(Path),
LitStr(String),
RawJson(Value),
}

impl ToString for PgLit {
fn to_string(&self) -> String {
match self {
PgLit::LitStr(str) => format!("'{}'", str.clone()),
PgLit::JsonField(path) => format!("cmetadata#>>'{{{}}}'", path.join(",")),
PgLit::RawJson(value) => serde_json::to_string(value).unwrap_or("null".to_string()),
}
}
}

impl ToString for PgFilter {
fn to_string(&self) -> String {
match self {
PgFilter::Eq(a, b) => format!("{} = {}", a.to_string(), b.to_string()),
PgFilter::Cmp(ordering, a, b) => {
let op = match ordering {
std::cmp::Ordering::Less => "<",
std::cmp::Ordering::Greater => ">",
std::cmp::Ordering::Equal => "=",
};
format!("{} {} {}", a.to_string(), op, b.to_string())
}
PgFilter::In(a, values) => {
format!(
"{} = ANY(ARRAY[{}])",
a.to_string(),
values
.iter()
.map(|s| format!("'{}'", s))
.collect::<Vec<String>>()
.join(",")
)
}
PgFilter::And(pgfilters) => pgfilters
.iter()
.map(|pgf| pgf.to_string())
.collect::<Vec<String>>()
.join(" AND "),
PgFilter::Or(pgfilters) => pgfilters
.iter()
.map(|pgf| pgf.to_string())
.collect::<Vec<String>>()
.join(" OR "),
}
}
}

pub struct HNSWIndex {
Expand All @@ -43,28 +110,21 @@ impl HNSWIndex {
}

impl Store {
// getFilters return metadata filters, now only support map[key]value pattern
// TODO: should support more types like {"key1": {"key2":"values2"}} or {"key": ["value1", "values2"]}.
fn get_filters(&self, opt: &VecStoreOptions) -> Result<HashMap<String, Value>, Box<dyn Error>> {
fn get_filters(&self, opt: &PgOptions) -> Result<String, Box<dyn Error>> {
match &opt.filters {
Some(Value::Object(map)) => {
// Convert serde_json Map to HashMap<String, Value>
let filters = map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
Ok(filters)
}
None => Ok(HashMap::new()), // No filters provided
_ => Err("Invalid filters format".into()), // Filters provided but not in the expected format
Some(pgfilter) => Ok(pgfilter.to_string()),
None => Ok("TRUE".to_string()), // No filters provided
}
}

fn get_name_space(&self, opt: &VecStoreOptions) -> String {
fn get_name_space(&self, opt: &PgOptions) -> String {
match &opt.name_space {
Some(name_space) => name_space.clone(),
None => self.collection_name.clone(),
}
}

fn get_score_threshold(&self, opt: &VecStoreOptions) -> Result<f32, Box<dyn Error>> {
fn get_score_threshold(&self, opt: &PgOptions) -> Result<f32, Box<dyn Error>> {
match &opt.score_threshold {
Some(score_threshold) => {
if *score_threshold < 0.0 || *score_threshold > 1.0 {
Expand Down Expand Up @@ -102,12 +162,28 @@ impl Store {
Ok(())
}
}

pub type PgOptions = VecStoreOptions<PgFilter>;

impl Default for PgOptions {
fn default() -> Self {
PgOptions {
filters: None,
score_threshold: None,
name_space: None,
embedder: None,
}
}
}

#[async_trait]
impl VectorStore for Store {
type Options = PgOptions;

async fn add_documents(
&self,
docs: &[Document],
opt: &VecStoreOptions,
opt: &PgOptions,
) -> Result<Vec<String>, Box<dyn Error>> {
if opt.score_threshold.is_some() || opt.filters.is_some() || opt.name_space.is_some() {
return Err(Box::new(std::io::Error::new(
Expand Down Expand Up @@ -162,25 +238,10 @@ impl VectorStore for Store {
&self,
query: &str,
limit: usize,
opt: &VecStoreOptions,
opt: &PgOptions,
) -> Result<Vec<Document>, Box<dyn Error>> {
let collection_name = self.get_name_space(opt);
let filter = self.get_filters(opt)?;
let mut where_querys = filter
.iter()
.map(|(k, v)| {
format!(
"(data.cmetadata ->> '{}') = '{}'",
k,
v.to_string().trim_matches('"')
)
})
.collect::<Vec<String>>()
.join(" AND ");

if where_querys.is_empty() {
where_querys = "TRUE".to_string();
}
let where_filter = self.get_filters(opt)?;

let sql = format!(
r#"WITH filtered_embedding_dims AS MATERIALIZED (
Expand Down Expand Up @@ -213,7 +274,7 @@ impl VectorStore for Store {
self.collection_table_name,
self.collection_table_name,
collection_name,
where_querys,
where_filter,
);

let query_vector = self.embedder.embed_query(query).await?;
Expand Down
10 changes: 7 additions & 3 deletions src/vectorstore/qdrant/qdrant.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_trait::async_trait;
use qdrant_client::client::Payload;
use qdrant_client::qdrant::{Filter, PointStruct, SearchPointsBuilder, UpsertPointsBuilder};
use serde_json::json;
use serde_json::{json, Value};
use std::error::Error;
use std::sync::Arc;

Expand All @@ -27,14 +27,18 @@ pub struct Store {
pub search_filter: Option<Filter>,
}

type QdrantOptions = VecStoreOptions<Value>;

#[async_trait]
impl VectorStore for Store {
type Options = QdrantOptions;

/// Add documents to the store.
/// Returns a list of document IDs added to the Qdrant collection.
async fn add_documents(
&self,
docs: &[Document],
opt: &VecStoreOptions,
opt: &QdrantOptions,
) -> Result<Vec<String>, Box<dyn Error>> {
let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();
Expand Down Expand Up @@ -69,7 +73,7 @@ impl VectorStore for Store {
&self,
query: &str,
limit: usize,
opt: &VecStoreOptions,
opt: &QdrantOptions,
) -> Result<Vec<Document>, Box<dyn Error>> {
if opt.name_space.is_some() {
return Err("Qdrant doesn't support namespaces".into());
Expand Down
10 changes: 7 additions & 3 deletions src/vectorstore/sqlite_vec/sqlite_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct Store {
pub(crate) embedder: Arc<dyn Embedder>,
}

pub type SqliteOptions = VecStoreOptions<Value>;

impl Store {
pub async fn initialize(&self) -> Result<(), Box<dyn Error>> {
self.create_table_if_not_exists().await?;
Expand Down Expand Up @@ -70,7 +72,7 @@ impl Store {
Ok(())
}

fn get_filters(&self, opt: &VecStoreOptions) -> Result<HashMap<String, Value>, Box<dyn Error>> {
fn get_filters(&self, opt: &SqliteOptions) -> Result<HashMap<String, Value>, Box<dyn Error>> {
match &opt.filters {
Some(Value::Object(map)) => {
// Convert serde_json Map to HashMap<String, Value>
Expand All @@ -85,10 +87,12 @@ impl Store {

#[async_trait]
impl VectorStore for Store {
type Options = SqliteOptions;

async fn add_documents(
&self,
docs: &[Document],
opt: &VecStoreOptions,
opt: &Self::Options,
) -> Result<Vec<String>, Box<dyn Error>> {
let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();

Expand Down Expand Up @@ -136,7 +140,7 @@ impl VectorStore for Store {
&self,
query: &str,
limit: usize,
opt: &VecStoreOptions,
opt: &Self::Options,
) -> Result<Vec<Document>, Box<dyn Error>> {
let table = &self.table;

Expand Down
Loading

0 comments on commit 60c7125

Please sign in to comment.