Reduce Drift Response and Streaming endpoint #1624

Merged
merged 10 commits on Jan 15, 2025
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250115181733910773.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add Drift Reduce response and streaming endpoint"
}
2 changes: 2 additions & 0 deletions graphrag/api/__init__.py
@@ -13,6 +13,7 @@
basic_search,
basic_search_streaming,
drift_search,
drift_search_streaming,
global_search,
global_search_streaming,
local_search,
@@ -29,6 +30,7 @@
"local_search",
"local_search_streaming",
"drift_search",
"drift_search_streaming",
"basic_search",
"basic_search_streaming",
# prompt tuning API
98 changes: 89 additions & 9 deletions graphrag/api/query.py
@@ -348,6 +348,87 @@ async def local_search_streaming(
yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.

Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
- community_level (int): The community level to search at.
- response_type (str): The type of response to return (e.g. "Multiple Paragraphs").
- query (str): The user query to search for.

Returns
-------
TODO: Document the search response type and format.

Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)

full_content_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)

entities_ = read_indexer_entities(nodes, entities, community_level)
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)

search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)

search_result = search_engine.astream_search(query=query)

# when streaming results, a context data object is returned as the first result
# and the query response is streamed in subsequent tokens
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
config: GraphRagConfig,
@@ -357,6 +438,7 @@ async def drift_search(
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
@@ -400,6 +482,10 @@
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)

search_engine = get_drift_search_engine(
config=config,
reports=reports,
@@ -408,21 +494,15 @@
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)

result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore

# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
# For the time being, return highest scoring response (position 0) and context data
match response:
case dict():
return response["nodes"][0]["answer"], context_data # type: ignore
case str():
return response, context_data
case list():
return response, context_data
return response, context_data


@validate_call(config={"arbitrary_types_allowed": True})
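For reference, a minimal consumer of the new drift_search_streaming endpoint could look like the sketch below. This is illustrative only and not part of the PR: the output folder layout, the parquet file names, and the create_graphrag_config(values, root_dir) loading call are assumptions based on the docstring above and the config module touched in this PR. The first item yielded is the context data; later items are response tokens.

# Hedged sketch: consuming the new streaming API (assumed project layout and config loading).
import asyncio
from pathlib import Path

import pandas as pd
import yaml

import graphrag.api as api
from graphrag.config.create_graphrag_config import create_graphrag_config  # assumed signature


async def main() -> None:
    root = Path("./ragtest")
    # Assumption: settings.yaml can be loaded this way; the CLI normally handles env substitution.
    settings = yaml.safe_load((root / "settings.yaml").read_text())
    config = create_graphrag_config(settings, root_dir=str(root))

    out = root / "output"  # assumed output location of the index artifacts
    nodes = pd.read_parquet(out / "create_final_nodes.parquet")
    entities = pd.read_parquet(out / "create_final_entities.parquet")
    community_reports = pd.read_parquet(out / "create_final_community_reports.parquet")
    text_units = pd.read_parquet(out / "create_final_text_units.parquet")
    relationships = pd.read_parquet(out / "create_final_relationships.parquet")

    context_data = None
    first_chunk = True
    async for chunk in api.drift_search_streaming(
        config=config,
        nodes=nodes,
        entities=entities,
        community_reports=community_reports,
        text_units=text_units,
        relationships=relationships,
        community_level=2,
        response_type="Multiple Paragraphs",
        query="What are the main themes in the dataset?",
    ):
        if first_chunk:
            context_data = chunk  # the context data object is yielded first
            first_chunk = False
        else:
            print(chunk, end="", flush=True)  # response tokens follow
    print()  # finish the line after the streamed answer
    if isinstance(context_data, dict):
        print("context tables:", list(context_data.keys()))


asyncio.run(main())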
6 changes: 5 additions & 1 deletion graphrag/cli/initialize.py
@@ -14,7 +14,10 @@
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
)
@@ -57,6 +60,7 @@ def initialize_project_at(path: Path) -> None:
"claim_extraction": CLAIM_EXTRACTION_PROMPT,
"community_report": COMMUNITY_REPORT_PROMPT,
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
"global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
"global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
3 changes: 2 additions & 1 deletion graphrag/cli/main.py
@@ -460,7 +460,8 @@ def _query_cli(
data_dir=data,
root_dir=root,
community_level=community_level,
streaming=False, # Drift search does not support streaming (yet)
streaming=streaming,
response_type=response_type,
query=query,
)
case SearchType.BASIC:
33 changes: 29 additions & 4 deletions graphrag/cli/query.py
@@ -202,6 +202,7 @@ def run_drift_search(
data_dir: Path | None,
root_dir: Path,
community_level: int,
response_type: str,
streaming: bool,
query: str,
):
@@ -234,8 +235,33 @@

# call the Query API
if streaming:
error_msg = "Streaming is not supported yet for DRIFT search."
raise NotImplementedError(error_msg)

async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.drift_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
return full_response, context_data

return asyncio.run(run_streaming_search())

# not streaming
response, context_data = asyncio.run(
@@ -247,6 +273,7 @@ def run_drift_search(
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
)
)
@@ -281,8 +308,6 @@ def run_basic_search(
)
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]

print(streaming) # noqa: T201

# # call the Query API
if streaming:

5 changes: 5 additions & 0 deletions graphrag/config/create_graphrag_config.py
@@ -589,6 +589,7 @@ def hydrate_parallelization_params(
):
drift_search_model = DRIFTSearchConfig(
prompt=reader.str("prompt") or None,
reduce_prompt=reader.str("reduce_prompt") or None,
temperature=reader.float("llm_temperature")
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
@@ -597,6 +598,10 @@
or defs.DRIFT_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
reduce_max_tokens=reader.int("reduce_max_tokens")
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
reduce_temperature=reader.float("reduce_temperature")
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
drift_k_followups=reader.int("drift_k_followups")
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
3 changes: 3 additions & 0 deletions graphrag/config/defaults.py
@@ -149,6 +149,9 @@
DRIFT_SEARCH_PRIMER_FOLDS = 5
DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000

DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE = 0
DRIFT_SEARCH_REDUCE_MAX_TOKENS = 2_000

DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9
DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1
DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10
1 change: 1 addition & 0 deletions graphrag/config/init_content.py
@@ -132,6 +132,7 @@

drift_search:
prompt: "prompts/drift_search_system_prompt.txt"
reduce_prompt: "prompts/drift_search_reduce_prompt.txt"

basic_search:
prompt: "prompts/basic_search_system_prompt.txt"
13 changes: 13 additions & 0 deletions graphrag/config/models/drift_search_config.py
@@ -14,6 +14,9 @@ class DRIFTSearchConfig(BaseModel):
prompt: str | None = Field(
description="The drift search prompt to use.", default=None
)
reduce_prompt: str | None = Field(
description="The drift search reduce prompt to use.", default=None
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
@@ -35,6 +38,16 @@
default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
)

reduce_max_tokens: int = Field(
description="The reduce llm maximum tokens response to produce.",
default=defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
)

reduce_temperature: float = Field(
description="The temperature to use for token generation in reduce.",
default=defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
)

concurrency: int = Field(
description="The number of concurrent requests.",
default=defs.DRIFT_SEARCH_CONCURRENCY,
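Taken together, these config changes let the new reduce stage be tuned from settings.yaml. The sketch below is a hedged example, not content from the PR: the key names mirror what create_graphrag_config reads (reduce_prompt, reduce_max_tokens, reduce_temperature), and the values shown are simply the new defaults from defaults.py.

drift_search:
  prompt: "prompts/drift_search_system_prompt.txt"
  reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
  reduce_max_tokens: 2000    # DRIFT_SEARCH_REDUCE_MAX_TOKENS default
  reduce_temperature: 0      # DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE default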
4 changes: 1 addition & 3 deletions graphrag/prompts/query/drift_search_system_prompt.py
@@ -106,7 +106,7 @@
---Target response length and format---
Multiple paragraphs
{response_type}
---Goal---
@@ -133,8 +133,6 @@
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
{query}
"""


4 changes: 4 additions & 0 deletions graphrag/query/factory.py
@@ -173,7 +173,9 @@ def get_drift_search_engine(
entities: list[Entity],
relationships: list[Relationship],
description_embedding_store: BaseVectorStore,
response_type: str,
local_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
@@ -191,7 +193,9 @@
entity_text_embeddings=description_embedding_store,
text_units=text_units,
local_system_prompt=local_system_prompt,
reduce_system_prompt=reduce_system_prompt,
config=config.drift_search,
response_type=response_type,
),
token_encoder=token_encoder,
)
@@ -19,6 +19,7 @@
from graphrag.model.text_unit import TextUnit
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
@@ -51,13 +52,16 @@ def __init__(
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None,
response_type: str | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
self.chat_llm = chat_llm
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
self.reduce_system_prompt = reduce_system_prompt or DRIFT_REDUCE_PROMPT

self.entities = entities
self.entity_text_embeddings = entity_text_embeddings
@@ -67,6 +71,8 @@ def __init__(
self.covariates = covariates
self.embedding_vectorstore_key = embedding_vectorstore_key

self.response_type = response_type

self.local_mixed_context = (
local_mixed_context or self.init_local_context_builder()
)
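The context-builder change resolves the reduce prompt with a simple or-fallback: a caller-supplied prompt wins, otherwise the packaged DRIFT_REDUCE_PROMPT is used. A tiny illustrative sketch of that precedence (not code from the PR):

# Illustrative only: the precedence applied by the DRIFT context builder above.
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_REDUCE_PROMPT

configured_reduce_prompt = None  # e.g. no reduce_prompt set in settings.yaml
effective_reduce_prompt = configured_reduce_prompt or DRIFT_REDUCE_PROMPT
print(effective_reduce_prompt[:120])  # preview the active reduce prompt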