Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add reranker to search #381

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion client/src/nv_ingest_client/util/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
from urllib.parse import urlparse
from typing import Union, Dict
import requests


def _dict_to_params(collections_dict: dict, write_params: dict):
Expand Down Expand Up @@ -761,6 +762,13 @@ def nvingest_retrieval(
model_name: str = "nvidia/nv-embedqa-e5-v5",
output_fields: List[str] = ["text", "source", "content_metadata"],
gpu_search: bool = True,
nv_ranker: bool = False,
nv_ranker_endpoint: str = "http://localhost:8015",
nv_ranker_model_name: str = "nvidia/llama-3.2-nv-rerankqa-1b-v2",
nv_ranker_nvidia_api_key: str = "",
nv_ranker_truncate: str = "END",
nv_ranker_top_k: int = 5,
nv_ranker_max_batch_size: int = 64,
):
"""
This function takes the input queries and conducts a hybrid/dense
Expand Down Expand Up @@ -792,7 +800,20 @@ def nvingest_retrieval(
The path where the sparse model has been loaded.
model_name : str, optional
The name of the dense embedding model available in the NIM embedding endpoint.

nv_ranker : bool
Set to True to use the nvidia reranker.
nv_ranker_endpoint : str
The endpoint to the nvidia reranker
nv_ranker_model_name: str
The name of the model hosted by the nvidia reranker
nv_ranker_nvidia_api_key : str,
The nvidia reranker api key, necessary when using non-local asset
nv_ranker_truncate : str [`END`, `NONE`]
Truncate the incoming texts if length is longer than the model allows.
nv_ranker_max_batch_size : int
Max size for the number of candidates to rerank.
nv_ranker_top_k : int,
The number of candidates to return after reranking.
Returns
-------
List
Expand All @@ -819,6 +840,22 @@ def nvingest_retrieval(
)
else:
results = dense_retrieval(queries, collection_name, client, embed_model, top_k, output_fields=output_fields)
if nv_ranker:
rerank_results = []
for query, candidates in zip(queries, results):
rerank_results.append(
nv_rerank(
query,
candidates,
reranker_endpoint=nv_ranker_endpoint,
model_name=nv_ranker_model_name,
nvidia_api_key=nv_ranker_nvidia_api_key,
truncate=nv_ranker_truncate,
topk=nv_ranker_top_k,
max_batch_size=nv_ranker_max_batch_size,
)
)

return results


Expand Down Expand Up @@ -850,3 +887,57 @@ def remove_records(source_name: str, collection_name: str, milvus_uri: str = "ht
filter=f'(source["source_name"] == "{source_name}")',
)
return result_ids


def nv_rerank(
    query,
    candidates,
    reranker_endpoint: str = "http://localhost:8015",
    model_name: str = "nvidia/llama-3.2-nv-rerankqa-1b-v2",
    nvidia_api_key: str = "",
    truncate: str = "END",
    max_batch_size: int = 64,
    topk: int = 5,
):
    """
    Rerank a set of candidates for a query using the nvidia reranker nim.

    The text of every candidate is posted to ``{reranker_endpoint}/v1/ranking``
    and the candidates are returned in the order given by the ``rankings``
    entry of the response, limited to the first ``topk`` entries.

    Parameters
    ----------
    query : str
        Query the candidates are supposed to answer.
    candidates : list
        List of the candidates to rerank. Each candidate must expose its text
        as ``candidate["entity"]["text"]``.
    reranker_endpoint : str
        The base URL of the nvidia reranker (a trailing ``/`` is tolerated).
    model_name : str
        The name of the model hosted by the nvidia reranker.
    nvidia_api_key : str
        The nvidia reranker api key, necessary when using a non-local asset.
        When non-empty it is sent as a ``Bearer`` token in the
        ``Authorization`` header.
    truncate : str [`END`, `NONE`]
        Truncate the incoming texts if length is longer than the model allows.
    max_batch_size : int
        Max size for the number of candidates to rerank. Currently not used by
        this function; all candidates are sent in a single request.
    topk : int
        The number of candidates to return after reranking. ``None`` returns
        every ranked candidate.

    Returns
    -------
    List
        The reranked candidates (at most ``topk`` of them), in the order
        reported by the reranker. Empty when ``candidates`` is empty.

    Raises
    ------
    requests.HTTPError
        If the reranker answers with an error status code.
    """
    # Materialize once so the indices returned by the service can be used
    # directly, whatever sequence/iterable type the caller passed in.
    candidates = list(candidates)
    if not candidates:
        # Nothing to rank: skip the network round trip entirely.
        return []

    headers = {"accept": "application/json", "Content-Type": "application/json"}
    if nvidia_api_key:
        headers["Authorization"] = f"Bearer {nvidia_api_key}"

    payload = {
        "model": model_name,
        "query": {"text": query},
        "passages": [{"text": candidate["entity"]["text"]} for candidate in candidates],
        "truncate": truncate,
    }
    base_url = reranker_endpoint.rstrip("/")
    response = requests.post(f"{base_url}/v1/ranking", headers=headers, json=payload)
    # Surface HTTP failures explicitly instead of failing later on a missing
    # "rankings" key in an error body.
    response.raise_for_status()

    rankings = response.json()["rankings"]
    if topk is not None:
        # The response order is kept as-is (expected to be best-first);
        # only the leading ``topk`` entries are returned.
        rankings = rankings[: max(topk, 0)]
    return [candidates[rank_vals["index"]] for rank_vals in rankings]
22 changes: 22 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,28 @@ services:
capabilities: [gpu]
runtime: nvidia

reranker:
# Reranker NIM service. Image name and tag can be overridden through
# RERANKER_IMAGE / RERANKER_TAG; the values after ":-" are the defaults.
# NIM ON
image: ${RERANKER_IMAGE:-nvcr.io/nim/nvidia/llama-3.2-nv-rerankqa-1b-v2}:${RERANKER_TAG:-1.3.0}
shm_size: 16gb
ports:
# Publishes container port 8000 on host port 8015.
- "8015:8000"
environment:
# Same port number as the container side of the port mapping above.
- NIM_HTTP_API_PORT=8000
- NIM_TRITON_LOG_VERBOSE=1
# Resolves to NIM_NGC_API_KEY, else NGC_API_KEY, else the placeholder "ngcapikey".
- NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}}
# NOTE(review): CUDA_VISIBLE_DEVICES is 0 while the reservation below asks for
# device id "1" - presumably the single reserved GPU is index 0 inside the
# container; confirm.
- CUDA_VISIBLE_DEVICES=0
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["1"]
capabilities: [gpu]
runtime: nvidia
# Only started when the "retrieval" Compose profile is enabled.
profiles:
- retrieval

nv-ingest-ms-runtime:
image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.10.1
build:
Expand Down
Loading