Skip to content

Commit

Permalink
Add Astra DB vector store implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 23, 2024
1 parent 16b4ea5 commit 36df323
Show file tree
Hide file tree
Showing 7 changed files with 394 additions and 5 deletions.
1 change: 1 addition & 0 deletions dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ numpy
pypi
nbformat
semversioner
astrapy

# Library Methods
iterrows
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/verbs/text/embed/text_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def text_embed(
max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
vector_store: # The optional configuration for the vector store
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, astradb
<...>
```
"""
Expand Down
2 changes: 2 additions & 0 deletions graphrag/vector_stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

"""A package containing vector-storage implementations."""

from .astradb import AstraDBVectorStore
from .azure_ai_search import AzureAISearch
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
from .lancedb import LanceDBVectorStore
from .typing import VectorStoreFactory, VectorStoreType

__all__ = [
"AstraDBVectorStore",
"AzureAISearch",
"BaseVectorStore",
"LanceDBVectorStore",
Expand Down
134 changes: 134 additions & 0 deletions graphrag/vector_stores/astradb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Astra DB vector store implementation package."""

import json
from typing import Any

from astrapy import DataAPIClient
from typing_extensions import override

from graphrag.model.types import TextEmbedder

from .base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)


class AstraDBVectorStore(BaseVectorStore):
    """The Astra DB vector storage implementation.

    Each document is written as one Astra DB document with the text under
    ``content``, the embedding under ``$vector`` and the attributes
    JSON-encoded under ``metadata``. Search results decode ``metadata`` back
    into the attributes mapping so that writes and reads round-trip.
    """

    @override
    def connect(
        self,
        *,
        token: str | None = None,
        database_id: str | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Connect to the Astra DB database.

        Parameters
        ----------
        token :
            The Astra DB application token (AstraCS:xyz...).
        database_id :
            The database ID or the corresponding API Endpoint.
        namespace :
            The database namespace. If not provided, an environment-specific
            default namespace is used.
        **kwargs :
            Additional arguments passed to the ``DataAPIClient.get_database``
            method.
        """
        self.db_connection = DataAPIClient(token).get_database(
            database_id, namespace=namespace, **kwargs
        )

    @override
    def load_documents(
        self, documents: list[VectorStoreDocument], overwrite: bool = True
    ) -> None:
        """Insert documents into the collection, optionally replacing it.

        Parameters
        ----------
        documents :
            The documents to store. Documents whose ``vector`` is ``None``
            are skipped.
        overwrite :
            When true, the existing collection is dropped first and then
            re-created if there is anything to insert.
        """
        if overwrite:
            self.db_connection.drop_collection(self.collection_name)
            # The cached handle now refers to a dropped collection; forget it
            # so a later non-overwriting call cannot insert into it.
            self.document_collection = None

        if not documents:
            return

        if self.document_collection is None:
            # Size the collection from the first non-empty embedding, falling
            # back to the package default when none of the documents has one.
            dimension = next(
                (len(doc.vector) for doc in documents if doc.vector),
                DEFAULT_VECTOR_SIZE,
            )
            self.document_collection = self.db_connection.create_collection(
                self.collection_name,
                dimension=dimension,
                check_exists=False,
            )

        batch = [
            {
                "content": doc.text,
                "_id": doc.id,
                "$vector": doc.vector,
                "metadata": json.dumps(doc.attributes),
            }
            for doc in documents
            if doc.vector is not None
        ]

        if batch:
            self.document_collection.insert_many(batch)

    @override
    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
        """Build, remember and return the query filter restricting by id.

        An empty or missing ``include_ids`` resets the filter to ``{}`` (no
        restriction); otherwise the filter matches ``_id`` against the ids.
        """
        if not include_ids:
            self.query_filter = {}
        else:
            self.query_filter = {"_id": {"$in": include_ids}}
        return self.query_filter

    @override
    def similarity_search_by_vector(
        self, query_embedding: list[float], k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Return up to ``k`` stored documents ranked against ``query_embedding``.

        The filter last set through ``filter_by_id`` (if any) is applied.
        Each result carries the ``$similarity`` value reported by the
        database as its score.
        """
        response = self._get_collection().find(
            filter=self.query_filter or {},
            projection={
                "_id": True,
                "content": True,
                "metadata": True,
                "$vector": True,
            },
            limit=k,
            include_similarity=True,
            sort={"$vector": query_embedding},
        )
        return [
            VectorStoreSearchResult(
                document=VectorStoreDocument(
                    id=doc["_id"],
                    text=doc["content"],
                    vector=doc["$vector"],
                    attributes=self._decode_attributes(doc.get("metadata")),
                ),
                score=doc["$similarity"],
            )
            for doc in response
        ]

    @override
    def similarity_search_by_text(
        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Embed ``text`` with ``text_embedder`` and search by that vector.

        Returns an empty list when the embedder yields no embedding.
        """
        query_embedding = text_embedder(text)
        if not query_embedding:
            return []
        return self.similarity_search_by_vector(
            query_embedding=query_embedding, k=k, **kwargs
        )

    def _get_collection(self) -> Any:
        """Return the collection handle, binding it on first use.

        ``connect`` only opens the database, so a search issued before any
        ``load_documents`` call would otherwise fail on a missing handle.
        """
        if self.document_collection is None:
            self.document_collection = self.db_connection.get_collection(
                self.collection_name
            )
        return self.document_collection

    @staticmethod
    def _decode_attributes(raw: Any) -> dict[str, Any]:
        """Turn a stored ``metadata`` value back into an attributes mapping.

        ``load_documents`` writes the attributes as a JSON string; values
        that are already decoded are passed through, and missing or null
        metadata becomes an empty mapping.
        """
        decoded = json.loads(raw) if isinstance(raw, str) else raw
        return decoded or {}
8 changes: 5 additions & 3 deletions graphrag/vector_stores/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from enum import Enum
from typing import ClassVar

from .azure_ai_search import AzureAISearch
from .lancedb import LanceDBVectorStore
from . import AstraDBVectorStore, AzureAISearch, BaseVectorStore, LanceDBVectorStore


class VectorStoreType(str, Enum):
    """Identifiers of the vector store backends that can be selected."""

    # Members subclass ``str`` so each one compares equal to its raw string
    # value.
    AstraDB = "astradb"
    LanceDB = "lancedb"
    AzureAISearch = "azure_ai_search"

Expand All @@ -30,9 +30,11 @@ def register(cls, vector_store_type: str, vector_store: type):
@classmethod
def get_vector_store(
cls, vector_store_type: VectorStoreType | str, kwargs: dict
) -> LanceDBVectorStore | AzureAISearch:
) -> BaseVectorStore:
"""Get the vector store type from a string."""
match vector_store_type:
case VectorStoreType.AstraDB:
return AstraDBVectorStore(**kwargs)
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
case VectorStoreType.AzureAISearch:
Expand Down
Loading

0 comments on commit 36df323

Please sign in to comment.