diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index eebf2fa05e..02e3a35c41 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -33,7 +33,7 @@ ) -class AzureAISearch(BaseVectorStore): +class AzureAISearchVectorStore(BaseVectorStore): """Azure AI Search vector storage implementation.""" index_client: SearchIndexClient diff --git a/graphrag/vector_stores/cosmosdb_vector_store.py b/graphrag/vector_stores/cosmosdb.py similarity index 79% rename from graphrag/vector_stores/cosmosdb_vector_store.py rename to graphrag/vector_stores/cosmosdb.py index 7c602eac2d..3683203cd6 100644 --- a/graphrag/vector_stores/cosmosdb_vector_store.py +++ b/graphrag/vector_stores/cosmosdb.py @@ -57,32 +57,28 @@ def connect(self, **kwargs: Any) -> Any: self._container_name = collection_name self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE) - self.create_database() - self.create_container() + self._create_database() + self._create_container() - def create_database(self) -> None: + def _create_database(self) -> None: """Create the database if it doesn't exist.""" - database_name = self._database_name - self._cosmos_client.create_database_if_not_exists(id=database_name) - self._database_client = self._cosmos_client.get_database_client(database_name) + self._cosmos_client.create_database_if_not_exists(id=self._database_name) + self._database_client = self._cosmos_client.get_database_client(self._database_name) - def delete_database(self) -> None: + def _delete_database(self) -> None: """Delete the database if it exists.""" - if self.database_exists(): + if self._database_exists(): self._cosmos_client.delete_database(self._database_name) - def database_exists(self) -> bool: + def _database_exists(self) -> bool: """Check if the database exists.""" - database_name = self._database_name - database_names = [ + existing_database_names = [ database["id"] for database in self._cosmos_client.list_databases() ] - return database_name in database_names + return self._database_name in existing_database_names - def create_container(self) -> None: + def _create_container(self) -> None: """Create the container if it doesn't exist.""" - database_client = self._database_client - container_name = self._container_name partition_key = PartitionKey(path="/id", kind="Hash") # Define the container vector policy @@ -107,39 +103,36 @@ def create_container(self) -> None: } # Create the container and container client - database_client.create_container_if_not_exists( - id=container_name, + self._database_client.create_container_if_not_exists( + id=self._container_name, partition_key=partition_key, indexing_policy=indexing_policy, vector_embedding_policy=vector_embedding_policy, ) - self._container_client = database_client.get_container_client(container_name) + self._container_client = self._database_client.get_container_client(self._container_name) - def delete_container(self) -> None: + def _delete_container(self) -> None: """Delete the vector store container in the database if it exists.""" - database_client = self._database_client - if self.container_exists(): - database_client.delete_container(self._container_name) + if self._container_exists(): + self._database_client.delete_container(self._container_name) - def container_exists(self) -> bool: + def _container_exists(self) -> bool: """Check if the container name exists in the database.""" - database_client = self._database_client - container_names = [ - container["id"] for container in database_client.list_containers() + existing_container_names = [ + container["id"] for container in self._database_client.list_containers() ] - return self._container_name in container_names + return self._container_name in existing_container_names def load_documents( self, documents: list[VectorStoreDocument], overwrite: bool = True ) -> None: """Load documents into CosmosDB.""" - # Create the CosmosDB container, if it doesn't exist + # Create a CosmosDB container on overwrite if overwrite: - self.delete_container() - self.create_container() + self._delete_container() + self._create_container() - container_client = self._container_client - if container_client is None: + if self._container_client is None: msg = "Container client is not initialized." raise ValueError(msg) @@ -152,32 +145,19 @@ def load_documents( "text": doc.text, "attributes": json.dumps(doc.attributes), } - container_client.upsert_item(doc_json) - - def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: - """Build a query filter to filter documents by a list of ids.""" - if include_ids is None or len(include_ids) == 0: - self.query_filter = None - else: - if isinstance(include_ids[0], str): - id_filter = ", ".join([f"'{id}'" for id in include_ids]) - else: - id_filter = ", ".join([str(id) for id in include_ids]) - self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608 - return self.query_filter + self._container_client.upsert_item(doc_json) def similarity_search_by_vector( self, query_embedding: list[float], k: int = 10, **kwargs: Any ) -> list[VectorStoreSearchResult]: """Perform a vector-based similarity search.""" - container_client = self._container_client - if container_client is None: + if self._container_client is None: msg = "Container client is not initialized." raise ValueError(msg) query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608 query_params = [{"name": "@embedding", "value": query_embedding}] - items = container_client.query_items( + items = self._container_client.query_items( query=query, parameters=query_params, enable_cross_partition_query=True, @@ -207,14 +187,25 @@ def similarity_search_by_text( ) return [] + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: + """Build a query filter to filter documents by a list of ids.""" + if include_ids is None or len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + id_filter = ", ".join([f"'{id}'" for id in include_ids]) + else: + id_filter = ", ".join([str(id) for id in include_ids]) + self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608 + return self.query_filter + def search_by_id(self, id: str) -> VectorStoreDocument: """Search for a document by id.""" - container_client = self._container_client - if container_client is None: + if self._container_client is None: msg = "Container client is not initialized." raise ValueError(msg) - item = container_client.read_item(item=id, partition_key=id) + item = self._container_client.read_item(item=id, partition_key=id) return VectorStoreDocument( id=item.get("id", ""), vector=item.get("vector", []), diff --git a/graphrag/vector_stores/factory.py b/graphrag/vector_stores/factory.py index 488be08838..1c37316d0c 100644 --- a/graphrag/vector_stores/factory.py +++ b/graphrag/vector_stores/factory.py @@ -6,9 +6,9 @@ from enum import Enum from typing import ClassVar -from graphrag.vector_stores.azure_ai_search import AzureAISearch +from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore from graphrag.vector_stores.base import BaseVectorStore -from graphrag.vector_stores.cosmosdb_vector_store import CosmosDBVectoreStore +from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore from graphrag.vector_stores.lancedb import LanceDBVectorStore @@ -42,7 +42,7 @@ def create_vector_store( case VectorStoreType.LanceDB: return LanceDBVectorStore(**kwargs) case VectorStoreType.AzureAISearch: - return AzureAISearch(**kwargs) + return AzureAISearchVectorStore(**kwargs) case VectorStoreType.CosmosDB: return CosmosDBVectoreStore(**kwargs) case _: