Add streaming support for local/global search (#944)
* Added streaming output support for global search. Introduced a `--streaming` flag to enable or disable streaming mode

* ran ruff format --preview

* update

* cleanup code and streaming api

* update cli argument

* remove whitespace

* checkpoint - add context data to streaming api

* cleanup help menu

* ruff format update

* add context data to streaming response

* add semversioner file

* rename variable for better readability

* rename variable for better readability

* ruff fixes

* fix abstract class type annotation

* add documentation for --streaming CLI flag

---------

Co-authored-by: 6GOD <[email protected]>
Co-authored-by: Alonso Guevara <[email protected]>
3 people authored Aug 20, 2024
1 parent a6238c6 commit 62546a3
Showing 10 changed files with 593 additions and 58 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20240816080238245653.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add streaming support for local/global search to query cli"
}
1 change: 1 addition & 0 deletions docsite/posts/query/3-cli.md
@@ -19,6 +19,7 @@ python -m graphrag.query --config <config_file.yml> --data <path-to-data> --comm
- `--community_level <community-level>` - Community level in the Leiden community hierarchy from which we will load the community reports; a higher value means we use reports on smaller communities. Default: 2
- `--response_type <response-type>` - Free-form text describing the response type and format; can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`.
- `--method <"local"|"global">` - Method to use to answer the query, one of local or global. For more information, check [Overview](overview.md).
- `--streaming` - Stream back the LLM response chunk by chunk (see the example below).
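
For example, to stream a global search response (an illustrative invocation; substitute your own data path and query):

python -m graphrag.query --data <path-to-data> --method global --streaming "What are the top themes in this dataset?"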

## Env Variables

8 changes: 8 additions & 0 deletions graphrag/query/__main__.py
@@ -72,6 +72,12 @@ def __str__(self):
default="Multiple Paragraphs",
)

parser.add_argument(
"--streaming",
help="Output response in a streaming (chunk-by-chunk) manner",
action="store_true",
)

parser.add_argument(
"query",
nargs=1,
@@ -89,6 +95,7 @@ def __str__(self):
args.root,
args.community_level,
args.response_type,
args.streaming,
args.query[0],
)
case SearchType.GLOBAL:
@@ -98,6 +105,7 @@ def __str__(self):
args.root,
args.community_level,
args.response_type,
args.streaming,
args.query[0],
)
case _:
265 changes: 216 additions & 49 deletions graphrag/query/api.py
@@ -7,10 +7,17 @@
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.

Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results via a generator.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results via a generator.

WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from collections.abc import AsyncGenerator
from typing import Any

import pandas as pd
@@ -35,53 +42,6 @@
reporter = PrintProgressReporter("")


def __get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}

collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)

description_embedding_store.connect(**config_args)

if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database

# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)

# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)

return description_embedding_store


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
config: GraphRagConfig,
@@ -125,6 +85,61 @@ async def global_search(
return result.response


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a global search and return results as a generator.
Context data is returned as a dictionary of lists, with one list entry for each record.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- community_level (int): The community level to search at.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
reports = read_indexer_reports(community_reports, nodes, community_level)
_entities = read_indexer_entities(nodes, entities, community_level)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=_entities,
response_type=response_type,
)
search_result = search_engine.astream_search(query=query)

# when streaming results, a context data object is returned as the first result
# and the query response is streamed in subsequent chunks
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk)
yield context_data
get_context_data = False
else:
yield stream_chunk
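
As a consumption sketch (hypothetical driver code, not part of this diff: the parquet paths, the `config` object, and the query string are all assumptions), the first yielded item is the reformatted context data and subsequent items are response tokens:

import asyncio

import pandas as pd

from graphrag.query.api import global_search_streaming


async def stream_global(config):  # config: a GraphRagConfig loaded elsewhere
    # assumed output locations; point these at your own index artifacts
    nodes = pd.read_parquet("output/create_final_nodes.parquet")
    entities = pd.read_parquet("output/create_final_entities.parquet")
    reports = pd.read_parquet("output/create_final_community_reports.parquet")
    is_first = True
    async for chunk in global_search_streaming(
        config=config,
        nodes=nodes,
        entities=entities,
        community_reports=reports,
        community_level=2,
        response_type="Multiple Paragraphs",
        query="What are the top themes in this dataset?",
    ):
        if is_first:
            # first chunk: context data as a dict of record lists
            print("context keys:", list(chunk))
            is_first = False
        else:
            # remaining chunks: response tokens
            print(chunk, end="", flush=True)

# asyncio.run(stream_global(config))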


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
config: GraphRagConfig,
@@ -164,16 +179,17 @@ async def local_search(
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)

reporter.info(f"Vector Store Args: {vector_store_args}")

vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = __get_embedding_description_store(
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)

_covariates = read_indexer_covariates(covariates) if covariates is not None else []

search_engine = get_local_search_engine(
@@ -190,3 +206,154 @@
result = await search_engine.asearch(query=query)
reporter.success(f"Local Search Response: {result.response}")
return result.response


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
covariates: pd.DataFrame | None,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a local search and return results as a generator.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
- covariates (pd.DataFrame): A DataFrame containing the final covariates (from create_final_covariates.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
reporter.info(f"Vector Store Args: {vector_store_args}")

vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)

_covariates = read_indexer_covariates(covariates) if covariates is not None else []

search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, nodes, community_level),
text_units=read_indexer_text_units(text_units),
entities=_entities,
relationships=read_indexer_relationships(relationships),
covariates={"claims": _covariates},
description_embedding_store=description_embedding_store,
response_type=response_type,
)
search_result = search_engine.astream_search(query=query)

# when streaming results, a context data object is returned as the first result
# and the query response is streamed in subsequent chunks
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk)
yield context_data
get_context_data = False
else:
yield stream_chunk


def _get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}

collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)

description_embedding_store.connect(**config_args)

if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database

# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)

# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)

return description_embedding_store


def _reformat_context_data(context_data: dict) -> dict:
"""
Reformats context_data for all query responses.
Reformats a dictionary of dataframes into a dictionary of lists.
One list entry for each record. Records are grouped by original
dictionary keys.
Note: depending on which query algorithm is used, the context_data may not
contain the same information (keys). In this case, the default behavior will be to
set these keys as empty lists to preserve a standard output format.
"""
final_format = {
"reports": [],
"entities": [],
"relationships": [],
"claims": [],
"sources": [],
}
for key in context_data:
records = context_data[key].to_dict(orient="records")
if len(records) < 1:
continue
final_format[key] = records
return final_format
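
To illustrate the reshaping (a made-up input; the DataFrame contents are not from this commit):

import pandas as pd

ctx = {"reports": pd.DataFrame([{"id": "1", "title": "Community 1"}])}
# _reformat_context_data(ctx) returns:
# {
#     "reports": [{"id": "1", "title": "Community 1"}],
#     "entities": [],
#     "relationships": [],
#     "claims": [],
#     "sources": [],
# }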