From 62546a3c14b39f260d2de53fcb4ebe2470cdb06d Mon Sep 17 00:00:00 2001 From: Josh Bradley Date: Tue, 20 Aug 2024 15:44:48 -0400 Subject: [PATCH] Add streaming support for local/global search (#944) * Added streaming output support for global search. Introduce `--streaming` flag to enable or disable streaming mode * ran ruff format --preview * update * cleanup code and streaming api * update cli argument * remove whitespace * checkpoint - add context data to streaming api * cleanup help menu * ruff format update * add context data to streaming response * add semversioner file * rename variable for better readability * rename variable for better readability * ruff fixes * fix abstract class type annotation * add documentation for --streaming CLI flag --------- Co-authored-by: 6GOD <55304045+6ixGODD@users.noreply.github.com> Co-authored-by: Alonso Guevara --- .../minor-20240816080238245653.json | 4 + docsite/posts/query/3-cli.md | 1 + graphrag/query/__main__.py | 8 + graphrag/query/api.py | 265 ++++++++++++++---- graphrag/query/cli.py | 77 ++++- graphrag/query/llm/base.py | 20 ++ graphrag/query/llm/oai/chat_openai.py | 125 ++++++++- graphrag/query/structured_search/base.py | 9 + .../structured_search/global_search/search.py | 110 ++++++++ .../structured_search/local_search/search.py | 32 +++ 10 files changed, 593 insertions(+), 58 deletions(-) create mode 100644 .semversioner/next-release/minor-20240816080238245653.json diff --git a/.semversioner/next-release/minor-20240816080238245653.json b/.semversioner/next-release/minor-20240816080238245653.json new file mode 100644 index 0000000000..3eacdade0d --- /dev/null +++ b/.semversioner/next-release/minor-20240816080238245653.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Add streaming support for local/global search to query cli" +} diff --git a/docsite/posts/query/3-cli.md b/docsite/posts/query/3-cli.md index 518d9cd871..97734d4f1d 100644 --- a/docsite/posts/query/3-cli.md +++ b/docsite/posts/query/3-cli.md @@ -19,6 +19,7 @@ python -m graphrag.query --config --data --comm - `--community_level ` - Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2 - `--response_type ` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`. - `--method <"local"|"global">` - Method to use to answer the query, one of local or global. 
For more information check [Overview](overview.md) +- `--streaming` - Stream back the LLM response ## Env Variables diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index 19ad00d5c9..41d6958803 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -72,6 +72,12 @@ def __str__(self): default="Multiple Paragraphs", ) + parser.add_argument( + "--streaming", + help="Output response in a streaming (chunk-by-chunk) manner", + action="store_true", + ) + parser.add_argument( "query", nargs=1, @@ -89,6 +95,7 @@ def __str__(self): args.root, args.community_level, args.response_type, + args.streaming, args.query[0], ) case SearchType.GLOBAL: @@ -98,6 +105,7 @@ def __str__(self): args.root, args.community_level, args.response_type, + args.streaming, args.query[0], ) case _: diff --git a/graphrag/query/api.py b/graphrag/query/api.py index 8f6f82470a..050af3b729 100644 --- a/graphrag/query/api.py +++ b/graphrag/query/api.py @@ -7,10 +7,17 @@ This API provides access to the query engine of graphrag, allowing external applications to hook into graphrag and run queries over a knowledge graph generated by graphrag. +Contains the following functions: + - global_search: Perform a global search. + - global_search_streaming: Perform a global search and stream results via a generator. + - local_search: Perform a local search. + - local_search_streaming: Perform a local search and stream results via a generator. + WARNING: This API is under development and may undergo changes in future releases. Backwards compatibility is not guaranteed at this time. """ +from collections.abc import AsyncGenerator from typing import Any import pandas as pd @@ -35,53 +42,6 @@ reporter = PrintProgressReporter("") -def __get_embedding_description_store( - entities: list[Entity], - vector_store_type: str = VectorStoreType.LanceDB, - config_args: dict | None = None, -): - """Get the embedding description store.""" - if not config_args: - config_args = {} - - collection_name = config_args.get( - "query_collection_name", "entity_description_embeddings" - ) - config_args.update({"collection_name": collection_name}) - description_embedding_store = VectorStoreFactory.get_vector_store( - vector_store_type=vector_store_type, kwargs=config_args - ) - - description_embedding_store.connect(**config_args) - - if config_args.get("overwrite", True): - # this step assumes the embeddings were originally stored in a file rather - # than a vector database - - # dump embeddings from the entities list to the description_embedding_store - store_entity_semantic_embeddings( - entities=entities, vectorstore=description_embedding_store - ) - else: - # load description embeddings to an in-memory lancedb vectorstore - # and connect to a remote db, specify url and port values. 
- description_embedding_store = LanceDBVectorStore( - collection_name=collection_name - ) - description_embedding_store.connect( - db_uri=config_args.get("db_uri", "./lancedb") - ) - - # load data from an existing table - description_embedding_store.document_collection = ( - description_embedding_store.db_connection.open_table( - description_embedding_store.collection_name - ) - ) - - return description_embedding_store - - @validate_call(config={"arbitrary_types_allowed": True}) async def global_search( config: GraphRagConfig, @@ -125,6 +85,61 @@ async def global_search( return result.response +@validate_call(config={"arbitrary_types_allowed": True}) +async def global_search_streaming( + config: GraphRagConfig, + nodes: pd.DataFrame, + entities: pd.DataFrame, + community_reports: pd.DataFrame, + community_level: int, + response_type: str, + query: str, +) -> AsyncGenerator: + """Perform a global search and return results as a generator. + + Context data is returned as a dictionary of lists, with one list entry for each record. + + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet) + - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet) + - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet) + - community_level (int): The community level to search at. + - response_type (str): The type of response to return. + - query (str): The user query to search for. + + Returns + ------- + TODO: Document the search response type and format. + + Raises + ------ + TODO: Document any exceptions to expect. 
+ """ + reports = read_indexer_reports(community_reports, nodes, community_level) + _entities = read_indexer_entities(nodes, entities, community_level) + search_engine = get_global_search_engine( + config, + reports=reports, + entities=_entities, + response_type=response_type, + ) + search_result = search_engine.astream_search(query=query) + + # when streaming results, a context data object is returned as the first result + # and the query response in subsequent tokens + context_data = None + get_context_data = True + async for stream_chunk in search_result: + if get_context_data: + context_data = _reformat_context_data(stream_chunk) + yield context_data + get_context_data = False + else: + yield stream_chunk + + @validate_call(config={"arbitrary_types_allowed": True}) async def local_search( config: GraphRagConfig, @@ -164,16 +179,17 @@ async def local_search( vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) - reporter.info(f"Vector Store Args: {vector_store_args}") + vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) _entities = read_indexer_entities(nodes, entities, community_level) - description_embedding_store = __get_embedding_description_store( + description_embedding_store = _get_embedding_description_store( entities=_entities, vector_store_type=vector_store_type, config_args=vector_store_args, ) + _covariates = read_indexer_covariates(covariates) if covariates is not None else [] search_engine = get_local_search_engine( @@ -190,3 +206,154 @@ async def local_search( result = await search_engine.asearch(query=query) reporter.success(f"Local Search Response: {result.response}") return result.response + + +@validate_call(config={"arbitrary_types_allowed": True}) +async def local_search_streaming( + config: GraphRagConfig, + nodes: pd.DataFrame, + entities: pd.DataFrame, + community_reports: pd.DataFrame, + text_units: pd.DataFrame, + relationships: pd.DataFrame, + covariates: pd.DataFrame | None, + community_level: int, + response_type: str, + query: str, +) -> AsyncGenerator: + """Perform a local search and return results as a generator. + + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet) + - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet) + - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet) + - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet) + - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet) + - covariates (pd.DataFrame): A DataFrame containing the final covariates (from create_final_covariates.parquet) + - community_level (int): The community level to search at. + - response_type (str): The response type to return. + - query (str): The user query to search for. + + Returns + ------- + TODO: Document the search response type and format. + + Raises + ------ + TODO: Document any exceptions to expect. 
+ """ + vector_store_args = ( + config.embeddings.vector_store if config.embeddings.vector_store else {} + ) + reporter.info(f"Vector Store Args: {vector_store_args}") + + vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) + + _entities = read_indexer_entities(nodes, entities, community_level) + description_embedding_store = _get_embedding_description_store( + entities=_entities, + vector_store_type=vector_store_type, + config_args=vector_store_args, + ) + + _covariates = read_indexer_covariates(covariates) if covariates is not None else [] + + search_engine = get_local_search_engine( + config=config, + reports=read_indexer_reports(community_reports, nodes, community_level), + text_units=read_indexer_text_units(text_units), + entities=_entities, + relationships=read_indexer_relationships(relationships), + covariates={"claims": _covariates}, + description_embedding_store=description_embedding_store, + response_type=response_type, + ) + search_result = search_engine.astream_search(query=query) + + # when streaming results, a context data object is returned as the first result + # and the query response in subsequent tokens + context_data = None + get_context_data = True + async for stream_chunk in search_result: + if get_context_data: + context_data = _reformat_context_data(stream_chunk) + yield context_data + get_context_data = False + else: + yield stream_chunk + + +def _get_embedding_description_store( + entities: list[Entity], + vector_store_type: str = VectorStoreType.LanceDB, + config_args: dict | None = None, +): + """Get the embedding description store.""" + if not config_args: + config_args = {} + + collection_name = config_args.get( + "query_collection_name", "entity_description_embeddings" + ) + config_args.update({"collection_name": collection_name}) + description_embedding_store = VectorStoreFactory.get_vector_store( + vector_store_type=vector_store_type, kwargs=config_args + ) + + description_embedding_store.connect(**config_args) + + if config_args.get("overwrite", True): + # this step assumes the embeddings were originally stored in a file rather + # than a vector database + + # dump embeddings from the entities list to the description_embedding_store + store_entity_semantic_embeddings( + entities=entities, vectorstore=description_embedding_store + ) + else: + # load description embeddings to an in-memory lancedb vectorstore + # and connect to a remote db, specify url and port values. + description_embedding_store = LanceDBVectorStore( + collection_name=collection_name + ) + description_embedding_store.connect( + db_uri=config_args.get("db_uri", "./lancedb") + ) + + # load data from an existing table + description_embedding_store.document_collection = ( + description_embedding_store.db_connection.open_table( + description_embedding_store.collection_name + ) + ) + + return description_embedding_store + + +def _reformat_context_data(context_data: dict) -> dict: + """ + Reformats context_data for all query responses. + + Reformats a dictionary of dataframes into a dictionary of lists. + One list entry for each record. Records are grouped by original + dictionary keys. + + Note: depending on which query algorithm is used, the context_data may not + contain the same information (keys). In this case, the default behavior will be to + set these keys as empty lists to preserve a standard output format. 
+ """ + final_format = { + "reports": [], + "entities": [], + "relationships": [], + "claims": [], + "sources": [], + } + for key in context_data: + records = context_data[key].to_dict(orient="records") + if len(records) < 1: + continue + final_format[key] = records + return final_format diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 915807a3a2..59eb4b454a 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -5,6 +5,7 @@ import asyncio import re +import sys from pathlib import Path from typing import cast @@ -22,11 +23,12 @@ def run_global_search( - config_dir: str | None, + config_filepath: str | None, data_dir: str | None, root_dir: str | None, community_level: int, response_type: str, + streaming: bool, query: str, ): """Perform a global search with a given query. @@ -34,7 +36,7 @@ def run_global_search( Loads index files required for global search and calls the Query API. """ data_dir, root_dir, config = _configure_paths_and_settings( - data_dir, root_dir, config_dir + data_dir, root_dir, config_filepath ) data_path = Path(data_dir) @@ -48,6 +50,34 @@ def run_global_search( data_path / "create_final_community_reports.parquet" ) + # call the Query API + if streaming: + + async def run_streaming_search(): + full_response = "" + context_data = None + get_context_data = True + async for stream_chunk in api.global_search_streaming( + config=config, + nodes=final_nodes, + entities=final_entities, + community_reports=final_community_reports, + community_level=community_level, + response_type=response_type, + query=query, + ): + if get_context_data: + context_data = stream_chunk + get_context_data = False + else: + full_response += stream_chunk + print(stream_chunk, end="") # noqa: T201 + sys.stdout.flush() # flush output buffer to display text immediately + print() # noqa: T201 + return full_response, context_data + + return asyncio.run(run_streaming_search()) + # not streaming return asyncio.run( api.global_search( config=config, @@ -62,11 +92,12 @@ def run_global_search( def run_local_search( - config_dir: str | None, + config_filepath: str | None, data_dir: str | None, root_dir: str | None, community_level: int, response_type: str, + streaming: bool, query: str, ): """Perform a local search with a given query. @@ -74,7 +105,7 @@ def run_local_search( Loads index files required for local search and calls the Query API. 
""" data_dir, root_dir, config = _configure_paths_and_settings( - data_dir, root_dir, config_dir + data_dir, root_dir, config_filepath ) data_path = Path(data_dir) @@ -95,6 +126,36 @@ def run_local_search( ) # call the Query API + if streaming: + + async def run_streaming_search(): + full_response = "" + context_data = None + get_context_data = True + async for stream_chunk in api.local_search_streaming( + config=config, + nodes=final_nodes, + entities=final_entities, + community_reports=final_community_reports, + text_units=final_text_units, + relationships=final_relationships, + covariates=final_covariates, + community_level=community_level, + response_type=response_type, + query=query, + ): + if get_context_data: + context_data = stream_chunk + get_context_data = False + else: + full_response += stream_chunk + print(stream_chunk, end="") # noqa: T201 + sys.stdout.flush() # flush output buffer to display text immediately + print() # noqa: T201 + return full_response, context_data + + return asyncio.run(run_streaming_search()) + # not streaming return asyncio.run( api.local_search( config=config, @@ -114,14 +175,14 @@ def run_local_search( def _configure_paths_and_settings( data_dir: str | None, root_dir: str | None, - config_dir: str | None, + config_filepath: str | None, ) -> tuple[str, str | None, GraphRagConfig]: if data_dir is None and root_dir is None: msg = "Either data_dir or root_dir must be provided." raise ValueError(msg) if data_dir is None: data_dir = _infer_data_dir(cast(str, root_dir)) - config = _create_graphrag_config(root_dir, config_dir) + config = _create_graphrag_config(root_dir, config_filepath) return data_dir, root_dir, config @@ -141,10 +202,10 @@ def _infer_data_dir(root: str) -> str: def _create_graphrag_config( root: str | None, - config_dir: str | None, + config_filepath: str | None, ) -> GraphRagConfig: """Create a GraphRag configuration.""" - return _read_config_parameters(root or "./", config_dir) + return _read_config_parameters(root or "./", config_filepath) def _read_config_parameters(root: str, config: str | None): diff --git a/graphrag/query/llm/base.py b/graphrag/query/llm/base.py index 228150af50..2c18bb29a1 100644 --- a/graphrag/query/llm/base.py +++ b/graphrag/query/llm/base.py @@ -4,6 +4,7 @@ """Base classes for LLM and Embedding models.""" from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Generator from typing import Any @@ -31,6 +32,15 @@ def generate( ) -> str: """Generate a response.""" + @abstractmethod + def stream_generate( + self, + messages: str | list[Any], + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> Generator[str, None, None]: + """Generate a response with streaming.""" + @abstractmethod async def agenerate( self, @@ -41,6 +51,16 @@ async def agenerate( ) -> str: """Generate a response asynchronously.""" + @abstractmethod + async def astream_generate( + self, + messages: str | list[Any], + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + """Generate a response asynchronously with streaming.""" + ... 
+ class BaseTextEmbedding(ABC): """The text embedding interface.""" diff --git a/graphrag/query/llm/oai/chat_openai.py b/graphrag/query/llm/oai/chat_openai.py index 92a9755b10..7dc3579a19 100644 --- a/graphrag/query/llm/oai/chat_openai.py +++ b/graphrag/query/llm/oai/chat_openai.py @@ -3,7 +3,7 @@ """Chat-based OpenAI LLM implementation.""" -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable, Generator from typing import Any from tenacity import ( @@ -92,6 +92,38 @@ def generate( # TODO: why not just throw in this case? return "" + def stream_generate( + self, + messages: str | list[Any], + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> Generator[str, None, None]: + """Generate text with streaming.""" + try: + retryer = Retrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + for attempt in retryer: + with attempt: + generator = self._stream_generate( + messages=messages, + callbacks=callbacks, + **kwargs, + ) + yield from generator + + except RetryError as e: + self._reporter.error( + message="Error at stream_generate()", + details={self.__class__.__name__: str(e)}, + ) + return + else: + return + async def agenerate( self, messages: str | list[Any], @@ -122,6 +154,35 @@ async def agenerate( # TODO: why not just throw in this case? return "" + async def astream_generate( + self, + messages: str | list[Any], + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + """Generate text asynchronously with streaming.""" + try: + retryer = AsyncRetrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), # type: ignore + ) + async for attempt in retryer: + with attempt: + generator = self._astream_generate( + messages=messages, + callbacks=callbacks, + **kwargs, + ) + async for response in generator: + yield response + except RetryError as e: + self._reporter.error(f"Error at astream_generate(): {e}") + return + else: + return + def _generate( self, messages: str | list[Any], @@ -163,6 +224,37 @@ def _generate( return full_response return response.choices[0].message.content or "" # type: ignore + def _stream_generate( + self, + messages: str | list[Any], + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> Generator[str, None, None]: + model = self.model + if not model: + raise ValueError(_MODEL_REQUIRED_MSG) + response = self.sync_client.chat.completions.create( # type: ignore + model=model, + messages=messages, # type: ignore + stream=True, + **kwargs, + ) + for chunk in response: + if not chunk or not chunk.choices: + continue + + delta = ( + chunk.choices[0].delta.content + if chunk.choices[0].delta and chunk.choices[0].delta.content + else "" + ) + + yield delta + + if callbacks: + for callback in callbacks: + callback.on_llm_new_token(delta) + async def _agenerate( self, messages: str | list[Any], @@ -204,3 +296,34 @@ async def _agenerate( return full_response return response.choices[0].message.content or "" # type: ignore + + async def _astream_generate( + self, + messages: str | list[Any], + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + model = self.model + if not model: + raise ValueError(_MODEL_REQUIRED_MSG) + response = await self.async_client.chat.completions.create( # type: ignore 
+ model=model, + messages=messages, # type: ignore + stream=True, + **kwargs, + ) + async for chunk in response: + if not chunk or not chunk.choices: + continue + + delta = ( + chunk.choices[0].delta.content + if chunk.choices[0].delta and chunk.choices[0].delta.content + else "" + ) # type: ignore + + yield delta + + if callbacks: + for callback in callbacks: + callback.on_llm_new_token(delta) diff --git a/graphrag/query/structured_search/base.py b/graphrag/query/structured_search/base.py index 6dd02485f8..e74a03f67a 100644 --- a/graphrag/query/structured_search/base.py +++ b/graphrag/query/structured_search/base.py @@ -4,6 +4,7 @@ """Base classes for search algos.""" from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import Any @@ -67,3 +68,11 @@ async def asearch( **kwargs, ) -> SearchResult: """Search for the given query asynchronously.""" + + @abstractmethod + def astream_search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + ) -> AsyncGenerator[str, None]: + """Stream search for the given query.""" diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index 12dc45fe7a..3ccde79003 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -7,6 +7,7 @@ import json import logging import time +from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import Any @@ -100,6 +101,37 @@ def __init__( self.semaphore = asyncio.Semaphore(concurrent_coroutines) + async def astream_search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + ) -> AsyncGenerator: + """Stream the global search response.""" + context_chunks, context_records = self.context_builder.build_context( + conversation_history=conversation_history, **self.context_builder_params + ) + if self.callbacks: + for callback in self.callbacks: + callback.on_map_response_start(context_chunks) # type: ignore + map_responses = await asyncio.gather(*[ + self._map_response_single_batch( + context_data=data, query=query, **self.map_llm_params + ) + for data in context_chunks + ]) + if self.callbacks: + for callback in self.callbacks: + callback.on_map_response_end(map_responses) # type: ignore + + # send context records first before sending the reduce response + yield context_records + async for response in self._stream_reduce_response( + map_responses=map_responses, # type: ignore + query=query, + **self.reduce_llm_params, + ): + yield response + async def asearch( self, query: str, @@ -357,3 +389,81 @@ async def _reduce_response( llm_calls=1, prompt_tokens=num_tokens(search_prompt, self.token_encoder), ) + + async def _stream_reduce_response( + self, + map_responses: list[SearchResult], + query: str, + **llm_kwargs, + ) -> AsyncGenerator[str, None]: + # collect all key points into a single list to prepare for sorting + key_points = [] + for index, response in enumerate(map_responses): + if not isinstance(response.response, list): + continue + for element in response.response: + if not isinstance(element, dict): + continue + if "answer" not in element or "score" not in element: + continue + key_points.append({ + "analyst": index, + "answer": element["answer"], + "score": element["score"], + }) + + # filter response with score = 0 and rank responses by descending order of score + filtered_key_points = [ + point + for point 
in key_points + if point["score"] > 0 # type: ignore + ] + + if len(filtered_key_points) == 0 and not self.allow_general_knowledge: + # return no data answer if no key points are found + log.warning( + "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations." + ) + yield NO_DATA_ANSWER + return + + filtered_key_points = sorted( + filtered_key_points, + key=lambda x: x["score"], # type: ignore + reverse=True, # type: ignore + ) + + data = [] + total_tokens = 0 + for point in filtered_key_points: + formatted_response_data = [ + f'----Analyst {point["analyst"] + 1}----', + f'Importance Score: {point["score"]}', + point["answer"], + ] + formatted_response_text = "\n".join(formatted_response_data) + if ( + total_tokens + num_tokens(formatted_response_text, self.token_encoder) + > self.max_data_tokens + ): + break + data.append(formatted_response_text) + total_tokens += num_tokens(formatted_response_text, self.token_encoder) + text_data = "\n\n".join(data) + + search_prompt = self.reduce_system_prompt.format( + report_data=text_data, response_type=self.response_type + ) + if self.allow_general_knowledge: + search_prompt += "\n" + self.general_knowledge_inclusion_prompt + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + async for resp in self.llm.astream_generate( # type: ignore + search_messages, + callbacks=self.callbacks, # type: ignore + **llm_kwargs, # type: ignore + ): + yield resp diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index 80dd667004..fd94e46f8d 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -5,6 +5,7 @@ import logging import time +from collections.abc import AsyncGenerator from typing import Any import tiktoken @@ -106,6 +107,37 @@ async def asearch( prompt_tokens=num_tokens(search_prompt, self.token_encoder), ) + async def astream_search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + ) -> AsyncGenerator: + """Build local search context that fits a single context window and generate answer for the user query.""" + start_time = time.time() + + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query) + search_prompt = self.system_prompt.format( + context_data=context_text, response_type=self.response_type + ) + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + # send context records first before sending the reduce response + yield context_records + async for response in self.llm.astream_generate( # type: ignore + messages=search_messages, + callbacks=self.callbacks, + **self.llm_params, + ): + yield response + def search( self, query: str,
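
For reference, the CLI can now stream either search mode by adding the new `--streaming` flag to the existing query command (for example, `python -m graphrag.query --root ./ragtest --method global --streaming "<query>"`). The sketch below shows one way an external application might consume the new `global_search_streaming` API directly. It is a minimal example based only on the signatures in this patch: the artifact directory, the config loader, the community level, and the query text are assumptions, not part of the change. As the comments in `graphrag/query/api.py` note, the first item yielded is the reformatted context data and every subsequent item is a response token.

    # Minimal consumer sketch (assumptions: artifact path, config loading, and
    # community level; the parquet file names come from the patch above).
    import asyncio

    import pandas as pd

    from graphrag.config import create_graphrag_config  # assumed config loader
    from graphrag.query import api


    async def run_streaming_global_search() -> tuple[str, dict | None]:
        data_dir = "./output/<timestamp>/artifacts"  # assumed artifact location
        nodes = pd.read_parquet(f"{data_dir}/create_final_nodes.parquet")
        entities = pd.read_parquet(f"{data_dir}/create_final_entities.parquet")
        reports = pd.read_parquet(f"{data_dir}/create_final_community_reports.parquet")
        config = create_graphrag_config(root_dir=".")  # assumed; point at your settings.yaml root

        full_response = ""
        context_data: dict | None = None
        first_chunk = True
        async for chunk in api.global_search_streaming(
            config=config,
            nodes=nodes,
            entities=entities,
            community_reports=reports,
            community_level=2,
            response_type="Multiple Paragraphs",
            query="What are the top themes in this dataset?",
        ):
            if first_chunk:
                # the first yielded item is the context data (a dict of record lists)
                context_data = chunk
                first_chunk = False
            else:
                # subsequent items are response tokens; print them as they arrive
                full_response += chunk
                print(chunk, end="", flush=True)
        print()
        return full_response, context_data


    if __name__ == "__main__":
        asyncio.run(run_streaming_global_search())

The same pattern applies to `local_search_streaming`, which additionally takes the text unit, relationship, and optional covariate DataFrames, mirroring the loading logic shown in `graphrag/query/cli.py` above.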