Add streaming endpoint
AlonsoGuevara committed Jan 15, 2025
1 parent d148589 commit 1542b27
Showing 8 changed files with 723 additions and 600 deletions.
2 changes: 2 additions & 0 deletions graphrag/api/__init__.py
@@ -13,6 +13,7 @@
basic_search,
basic_search_streaming,
drift_search,
drift_search_streaming,
global_search,
global_search_streaming,
local_search,
@@ -29,6 +30,7 @@
"local_search",
"local_search_streaming",
"drift_search",
"drift_search_streaming",
"basic_search",
"basic_search_streaming",
# prompt tuning API
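With the package exports updated, the new generator sits next to the existing streaming endpoints on graphrag.api. A minimal sketch of the resulting import surface (illustrative only; all of these names appear in the diff above, and only drift_search_streaming is new in this commit):

# The DRIFT streaming endpoint is importable from the package root,
# alongside the other search and streaming entry points.
from graphrag.api import (
    drift_search,
    drift_search_streaming,
    global_search_streaming,
    local_search_streaming,
)

assert callable(drift_search_streaming)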
95 changes: 85 additions & 10 deletions graphrag/api/query.py
@@ -348,6 +348,87 @@ async def local_search_streaming(
yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
- community_level (int): The community level to search at.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)

full_content_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)

entities_ = read_indexer_entities(nodes, entities, community_level)
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)

search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)

search_result = search_engine.astream_search(query=query)

# when streaming results, a context data object is returned as the first result
# and the query response is returned in subsequent tokens
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
config: GraphRagConfig,
@@ -401,7 +482,9 @@ async def drift_search(
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(config.root_dir, config.drift_search.reduce_prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)

search_engine = get_drift_search_engine(
config=config,
@@ -419,15 +502,7 @@
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore

# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
# For the time being, return highest scoring response (position 0) and context data
match response:
case dict():
return response["nodes"][0]["answer"], context_data # type: ignore
case str():
return response, context_data
case list():
return response, context_data
return response, context_data


@validate_call(config={"arbitrary_types_allowed": True})
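For orientation, a hedged sketch of how a caller might consume the new generator: the first yielded item is the reformatted context data, and every later chunk is a piece of the response text. The parquet filenames mirror the docstring above, while the output directory, community level, response type, and query are illustrative assumptions; loading the GraphRagConfig from settings.yaml is elided.

import asyncio
from pathlib import Path

import pandas as pd

from graphrag.api import drift_search_streaming

OUTPUT = Path("output/artifacts")  # assumed location of the indexing outputs


async def stream_drift(config) -> tuple[str, dict | None]:
    """Stream a DRIFT answer, returning the full response text and its context data."""
    # config must be a GraphRagConfig built from settings.yaml (loading elided here)
    nodes = pd.read_parquet(OUTPUT / "create_final_nodes.parquet")
    entities = pd.read_parquet(OUTPUT / "create_final_entities.parquet")
    community_reports = pd.read_parquet(
        OUTPUT / "create_final_community_reports.parquet"
    )
    text_units = pd.read_parquet(OUTPUT / "create_final_text_units.parquet")
    relationships = pd.read_parquet(OUTPUT / "create_final_relationships.parquet")

    full_response = ""
    context_data = None
    first = True
    async for chunk in drift_search_streaming(
        config=config,
        nodes=nodes,
        entities=entities,
        community_reports=community_reports,
        text_units=text_units,
        relationships=relationships,
        community_level=2,
        response_type="Multiple Paragraphs",
        query="What are the main themes in the dataset?",
    ):
        if first:
            context_data = chunk  # first yield is the context data
            first = False
        else:
            full_response += chunk  # later yields are response tokens
            print(chunk, end="", flush=True)
    print()
    return full_response, context_data


# asyncio.run(stream_drift(config)) once a GraphRagConfig is available.

This is the same first-chunk-is-context convention that the run_streaming_search wrapper added to graphrag/cli/query.py below relies on.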
2 changes: 1 addition & 1 deletion graphrag/cli/main.py
@@ -460,7 +460,7 @@ def _query_cli(
data_dir=data,
root_dir=root,
community_level=community_level,
streaming=False, # Drift search does not support streaming (yet)
streaming=streaming,
response_type=response_type,
query=query,
)
31 changes: 27 additions & 4 deletions graphrag/cli/query.py
@@ -235,8 +235,33 @@ def run_drift_search(

# call the Query API
if streaming:
error_msg = "Streaming is not supported yet for DRIFT search."
raise NotImplementedError(error_msg)

async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.drift_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
return full_response, context_data

return asyncio.run(run_streaming_search())

# not streaming
response, context_data = asyncio.run(
@@ -283,8 +308,6 @@ def run_basic_search(
)
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]

print(streaming) # noqa: T201

# # call the Query API
if streaming:

4 changes: 1 addition & 3 deletions graphrag/prompts/query/drift_search_system_prompt.py
@@ -106,7 +106,7 @@
---Target response length and format---
Multiple paragraphs
{response_type}
---Goal---
@@ -133,8 +133,6 @@
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
{query}
"""


@@ -18,7 +18,8 @@
from graphrag.model.relationship import Relationship
from graphrag.model.text_unit import TextUnit
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT, DRIFT_REDUCE_PROMPT
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
@@ -52,7 +53,7 @@ def __init__(
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None,
response_type: str | None = None
response_type: str | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()