Skip to content

Commit

Permalink
Merge pull request #1068 from sky-2002/epsilla-retriever
Browse files Browse the repository at this point in the history
feat(dspy): Add epsilla retriever
  • Loading branch information
arnavsinghvi11 authored Jun 13, 2024
2 parents 833ded7 + 41c399f commit fc58b9c
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
45 changes: 45 additions & 0 deletions dspy/retrieve/epsilla_rm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from collections import defaultdict # noqa: F401
from typing import Dict, List, Union # noqa: UP035

import dspy
from dsp.utils import dotdict

try:
from pyepsilla import vectordb
except ImportError:
raise ImportError( # noqa: B904
"The 'pyepsilla' extra is required to use EpsillaRM. Install it with `pip install dspy-ai[epsilla]`",
)


class EpsillaRM(dspy.Retrieve):
def __init__(
self,
epsilla_client: vectordb.Client,
db_name: str,
db_path: str,
table_name: str,
k: int = 3,
page_content: str = "document",
):
self._epsilla_client = epsilla_client
self._epsilla_client.load_db(db_name=db_name, db_path=db_path)
self._epsilla_client.use_db(db_name=db_name)
self.page_content = page_content
self.table_name = table_name

super().__init__(k=k)

def forward(self, query_or_queries: Union[str, List[str]], k: Union[int, None] = None, **kwargs) -> dspy.Prediction: # noqa: ARG002
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
queries = [q for q in queries if q]
limit = k if k else self.k
all_query_results: list = []

passages: Dict = defaultdict(float)

for result_dict in all_query_results:
for result in result_dict:
passages[result[self.page_content]] += result["@distance"]
sorted_passages = sorted(passages.items(), key=lambda x: x[1], reverse=False)[:limit]
return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
70 changes: 69 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ anthropic = ["anthropic~=0.18.0"]
chromadb = ["chromadb~=0.4.14"]
qdrant = ["qdrant-client>=1.6.2", "fastembed>=0.2.0"]
marqo = ["marqo"]
epsilla = ["pyepsilla~=0.3.7"]
pinecone = ["pinecone-client~=2.2.4"]
weaviate = ["weaviate-client~=4.5.4"]
milvus = ["pymilvus~=2.3.7"]
Expand Down Expand Up @@ -99,6 +100,7 @@ anthropic = { version = "^0.18.0", optional = true }
chromadb = { version = "^0.4.14", optional = true }
fastembed = { version = ">=0.2.0", optional = true }
marqo = { version = "*", optional = true }
pyepsilla = {version = "^0.3.7", optional = true}
qdrant-client = { version = "^1.6.2", optional = true }
pinecone-client = { version = "^2.2.4", optional = true }
weaviate-client = { version = "^4.5.4", optional = true }
Expand Down Expand Up @@ -140,6 +142,7 @@ ipykernel = "^6.29.4"
chromadb = ["chromadb"]
qdrant = ["qdrant-client", "fastembed"]
marqo = ["marqo"]
epsilla = ["pyepsilla"]
pinecone = ["pinecone-client"]
weaviate = ["weaviate-client"]
milvus = ["pymilvus"]
Expand Down

0 comments on commit fc58b9c

Please sign in to comment.