Skip to content

Commit

Permalink
small refactor changes for better alignment with other vector stores
Browse files Browse the repository at this point in the history
  • Loading branch information
jgbradley1 committed Jan 13, 2025
1 parent e78680b commit b1d6f0c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 55 deletions.
2 changes: 1 addition & 1 deletion graphrag/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)


class AzureAISearch(BaseVectorStore):
class AzureAISearchVectorStore(BaseVectorStore):
"""Azure AI Search vector storage implementation."""

index_client: SearchIndexClient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,28 @@ def connect(self, **kwargs: Any) -> Any:
self._container_name = collection_name

self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
self.create_database()
self.create_container()
self._create_database()
self._create_container()

def create_database(self) -> None:
def _create_database(self) -> None:
"""Create the database if it doesn't exist."""
database_name = self._database_name
self._cosmos_client.create_database_if_not_exists(id=database_name)
self._database_client = self._cosmos_client.get_database_client(database_name)
self._cosmos_client.create_database_if_not_exists(id=self._database_name)
self._database_client = self._cosmos_client.get_database_client(self._database_name)

def delete_database(self) -> None:
def _delete_database(self) -> None:
"""Delete the database if it exists."""
if self.database_exists():
if self._database_exists():
self._cosmos_client.delete_database(self._database_name)

def database_exists(self) -> bool:
def _database_exists(self) -> bool:
"""Check if the database exists."""
database_name = self._database_name
database_names = [
existing_database_names = [
database["id"] for database in self._cosmos_client.list_databases()
]
return database_name in database_names
return self._database_name in existing_database_names

def create_container(self) -> None:
def _create_container(self) -> None:
"""Create the container if it doesn't exist."""
database_client = self._database_client
container_name = self._container_name
partition_key = PartitionKey(path="/id", kind="Hash")

# Define the container vector policy
Expand All @@ -107,39 +103,36 @@ def create_container(self) -> None:
}

# Create the container and container client
database_client.create_container_if_not_exists(
id=container_name,
self._database_client.create_container_if_not_exists(
id=self._container_name,
partition_key=partition_key,
indexing_policy=indexing_policy,
vector_embedding_policy=vector_embedding_policy,
)
self._container_client = database_client.get_container_client(container_name)
self._container_client = self._database_client.get_container_client(self._container_name)

def delete_container(self) -> None:
def _delete_container(self) -> None:
"""Delete the vector store container in the database if it exists."""
database_client = self._database_client
if self.container_exists():
database_client.delete_container(self._container_name)
if self._container_exists():
self._database_client.delete_container(self._container_name)

def container_exists(self) -> bool:
def _container_exists(self) -> bool:
"""Check if the container name exists in the database."""
database_client = self._database_client
container_names = [
container["id"] for container in database_client.list_containers()
existing_container_names = [
container["id"] for container in self._database_client.list_containers()
]
return self._container_name in container_names
return self._container_name in existing_container_names

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into CosmosDB."""
# Create the CosmosDB container, if it doesn't exist
# Create a CosmosDB container on overwrite
if overwrite:
self.delete_container()
self.create_container()
self._delete_container()
self._create_container()

container_client = self._container_client
if container_client is None:
if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)

Expand All @@ -152,32 +145,19 @@ def load_documents(
"text": doc.text,
"attributes": json.dumps(doc.attributes),
}
container_client.upsert_item(doc_json)

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by a list of ids."""
if include_ids is None or len(include_ids) == 0:
self.query_filter = None
else:
if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
return self.query_filter
self._container_client.upsert_item(doc_json)

def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
container_client = self._container_client
if container_client is None:
if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)

query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}]
items = container_client.query_items(
items = self._container_client.query_items(
query=query,
parameters=query_params,
enable_cross_partition_query=True,
Expand Down Expand Up @@ -207,14 +187,25 @@ def similarity_search_by_text(
)
return []

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by a list of ids."""
if include_ids is None or len(include_ids) == 0:
self.query_filter = None
else:
if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
return self.query_filter

def search_by_id(self, id: str) -> VectorStoreDocument:
"""Search for a document by id."""
container_client = self._container_client
if container_client is None:
if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)

item = container_client.read_item(item=id, partition_key=id)
item = self._container_client.read_item(item=id, partition_key=id)
return VectorStoreDocument(
id=item.get("id", ""),
vector=item.get("vector", []),
Expand Down
6 changes: 3 additions & 3 deletions graphrag/vector_stores/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from enum import Enum
from typing import ClassVar

from graphrag.vector_stores.azure_ai_search import AzureAISearch
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb_vector_store import CosmosDBVectoreStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore


Expand Down Expand Up @@ -42,7 +42,7 @@ def create_vector_store(
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
case VectorStoreType.AzureAISearch:
return AzureAISearch(**kwargs)
return AzureAISearchVectorStore(**kwargs)
case VectorStoreType.CosmosDB:
return CosmosDBVectoreStore(**kwargs)
case _:
Expand Down

0 comments on commit b1d6f0c

Please sign in to comment.