Skip to content

Commit

Permalink
Allow kwargs in body_func (#46)
Browse files Browse the repository at this point in the history
* Allow kwargs in body_func

* Remove async implementation

* Remove unused imports
  • Loading branch information
gdahia authored Sep 26, 2024
1 parent 439efe2 commit 00e593f
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions libs/elasticsearch/langchain_elasticsearch/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import logging
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Union,
cast,
)

from elasticsearch import Elasticsearch
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

Expand All @@ -29,6 +41,8 @@ class ElasticsearchRetriever(BaseRetriever):
document_mapper: Function to map Elasticsearch hits to LangChain Documents.
"""

_expects_other_args = True

es_client: Elasticsearch
index_name: Union[str, Sequence[str]]
body_func: Callable[[str], Dict]
Expand Down Expand Up @@ -94,12 +108,12 @@ def from_es_params(
)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
if not self.es_client or not self.document_mapper:
raise ValueError("faulty configuration") # should not happen

body = self.body_func(query)
body = self.body_func(query, **kwargs)
results = self.es_client.search(index=self.index_name, body=body)
return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

Expand Down

0 comments on commit 00e593f

Please sign in to comment.