diff --git a/.semversioner/next-release/minor-20241011205050985571.json b/.semversioner/next-release/minor-20241011205050985571.json new file mode 100644 index 0000000000..19abacd7c4 --- /dev/null +++ b/.semversioner/next-release/minor-20241011205050985571.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Added DRIFT graph reasoning query module" +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index b02d58ecef..b154e8f302 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -115,3 +115,28 @@ GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000 GLOBAL_SEARCH_REDUCE_MAX_TOKENS = 2_000 GLOBAL_SEARCH_CONCURRENCY = 32 + +# DRIFT Search + +DRIFT_SEARCH_LLM_TEMPERATURE = 0 +DRIFT_SEARCH_LLM_TOP_P = 1 +DRIFT_SEARCH_LLM_N = 3 +DRIFT_SEARCH_MAX_TOKENS = 12_000 +DRIFT_SEARCH_DATA_MAX_TOKENS = 12_000 +DRIFT_SEARCH_CONCURRENCY = 32 + +DRIFT_SEARCH_K_FOLLOW_UPS = 20 +DRIFT_SEARCH_PRIMER_FOLDS = 5 +DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000 + +DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9 +DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1 +DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10 +DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS = 10 +DRIFT_LOCAL_SEARCH_MAX_TOKENS = 12_000 +DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE = 0 +DRIFT_LOCAL_SEARCH_LLM_TOP_P = 1 +DRIFT_LOCAL_SEARCH_LLM_N = 1 +DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS = 2000 + +DRIFT_N_DEPTH = 3 diff --git a/graphrag/config/models/__init__.py b/graphrag/config/models/__init__.py index 43c4cde506..68691cc6ee 100644 --- a/graphrag/config/models/__init__.py +++ b/graphrag/config/models/__init__.py @@ -8,6 +8,7 @@ from .claim_extraction_config import ClaimExtractionConfig from .cluster_graph_config import ClusterGraphConfig from .community_reports_config import CommunityReportsConfig +from .drift_config import DRIFTSearchConfig from .embed_graph_config import EmbedGraphConfig from .entity_extraction_config import EntityExtractionConfig from .global_search_config import GlobalSearchConfig @@ -30,6 +31,7 @@ "ClaimExtractionConfig", "ClusterGraphConfig", "CommunityReportsConfig", + "DRIFTSearchConfig", "EmbedGraphConfig", "EntityExtractionConfig", "GlobalSearchConfig", diff --git a/graphrag/config/models/drift_config.py b/graphrag/config/models/drift_config.py new file mode 100644 index 0000000000..03b80150a2 --- /dev/null +++ b/graphrag/config/models/drift_config.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class DRIFTSearchConfig(BaseModel):
+    """The default configuration section for DRIFT search."""
+
+    temperature: float = Field(
+        description="The temperature to use for token generation.",
+        default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.DRIFT_SEARCH_LLM_TOP_P,
+    )
+    n: int = Field(
+        description="The number of completions to generate.",
+        default=defs.DRIFT_SEARCH_LLM_N,
+    )
+    max_tokens: int = Field(
+        description="The maximum context size in tokens.",
+        default=defs.DRIFT_SEARCH_MAX_TOKENS,
+    )
+    data_max_tokens: int = Field(
+        description="The maximum number of tokens for the data LLM.",
+        default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
+    )
+
+    concurrency: int = Field(
+        description="The number of concurrent requests.",
+        default=defs.DRIFT_SEARCH_CONCURRENCY,
+    )
+
+    drift_k_followups: int = Field(
+        description="The number of top global results to retrieve.",
+        default=defs.DRIFT_SEARCH_K_FOLLOW_UPS,
+    )
+
+    primer_folds: int = Field(
+        description="The number of folds for search priming.",
+        default=defs.DRIFT_SEARCH_PRIMER_FOLDS,
+    )
+
+    primer_llm_max_tokens: int = Field(
+        description="The maximum number of tokens for the LLM in primer.",
+        default=defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
+    )
+
+    n_depth: int = Field(
+        description="The number of drift search steps to take.",
+        default=defs.DRIFT_N_DEPTH,
+    )
+
+    local_search_text_unit_prop: float = Field(
+        description="The proportion of search dedicated to text units.",
+        default=defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
+    )
+
+    local_search_community_prop: float = Field(
+        description="The proportion of search dedicated to community reports.",
+        default=defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
+    )
+
+    local_search_top_k_mapped_entities: int = Field(
+        description="The number of top K entities to map during local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
+    )
+
+    local_search_top_k_relationships: int = Field(
+        description="The number of top K relationships to map during local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
+    )
+
+    local_search_max_data_tokens: int = Field(
+        description="The maximum context size in tokens for local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
+    )
+
+    local_search_temperature: float = Field(
+        description="The temperature to use for token generation in local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
+    )
+
+    local_search_top_p: float = Field(
+        description="The top-p value to use for token generation in local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
+    )
+
+    local_search_n: int = Field(
+        description="The number of completions to generate in local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_LLM_N,
+    )
+
+    local_search_llm_max_gen_tokens: int = Field(
+        description="The maximum number of generated tokens for the LLM in local search.",
+        default=defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
+    )
diff --git a/graphrag/query/context_builder/builders.py b/graphrag/query/context_builder/builders.py
index 7a4ba277ae..c47ee4d4bc 100644
--- a/graphrag/query/context_builder/builders.py
+++ b/graphrag/query/context_builder/builders.py
@@ -33,3 +33,15 @@ def build_context(
         **kwargs,
     ) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
         """Build the context for the local search
mode.""" + + +class DRIFTContextBuilder(ABC): + """Base class for DRIFT-search context builders.""" + + @abstractmethod + def build_context( + self, + query: str, + **kwargs, + ) -> pd.DataFrame: + """Build the context for the primer search actions.""" diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 7d07e1d700..3ad520bfbc 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -21,7 +21,6 @@ from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.oai.embedding import OpenAIEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType -from graphrag.query.structured_search.base import BaseSearch from graphrag.query.structured_search.global_search.community_context import ( GlobalCommunityContext, ) @@ -107,7 +106,7 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, -) -> BaseSearch: +) -> LocalSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) text_embedder = get_text_embedder(config) @@ -158,7 +157,7 @@ def get_global_search_engine( reports: list[CommunityReport], entities: list[Entity], response_type: str, -) -> BaseSearch: +) -> GlobalSearch: """Create a global search engine based on data + configuration.""" token_encoder = tiktoken.get_encoding(config.encoding_model) gs_config = config.global_search diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py index dba4712bb4..812a3c23a7 100644 --- a/graphrag/query/indexer_adapters.py +++ b/graphrag/query/indexer_adapters.py @@ -63,6 +63,7 @@ def read_indexer_reports( final_community_reports: pd.DataFrame, final_nodes: pd.DataFrame, community_level: int, + content_embedding_col: str | None = None, ) -> list[CommunityReport]: """Read in the Community Reports from the raw indexing outputs.""" report_df = final_community_reports @@ -83,7 +84,7 @@ def read_indexer_reports( id_col="community", short_id_col="community", summary_embedding_col=None, - content_embedding_col=None, + content_embedding_col=content_embedding_col, ) diff --git a/graphrag/query/structured_search/base.py b/graphrag/query/structured_search/base.py index e74a03f67a..278d8c4bfc 100644 --- a/graphrag/query/structured_search/base.py +++ b/graphrag/query/structured_search/base.py @@ -6,12 +6,13 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from dataclasses import dataclass -from typing import Any +from typing import Any, Generic, TypeVar import pandas as pd import tiktoken from graphrag.query.context_builder.builders import ( + DRIFTContextBuilder, GlobalContextBuilder, LocalContextBuilder, ) @@ -34,13 +35,16 @@ class SearchResult: prompt_tokens: int -class BaseSearch(ABC): +T = TypeVar("T", GlobalContextBuilder, LocalContextBuilder, DRIFTContextBuilder) + + +class BaseSearch(ABC, Generic[T]): """The Base Search implementation.""" def __init__( self, llm: BaseLLM, - context_builder: GlobalContextBuilder | LocalContextBuilder, + context_builder: T, token_encoder: tiktoken.Encoding | None = None, llm_params: dict[str, Any] | None = None, context_builder_params: dict[str, Any] | None = None, @@ -74,5 +78,5 @@ def astream_search( self, query: str, conversation_history: ConversationHistory | None = None, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[str, None] | None: """Stream search for the given query.""" diff --git a/graphrag/query/structured_search/drift_search/__init__.py 
b/graphrag/query/structured_search/drift_search/__init__.py new file mode 100644 index 0000000000..fe58251741 --- /dev/null +++ b/graphrag/query/structured_search/drift_search/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""DriftSearch module.""" diff --git a/graphrag/query/structured_search/drift_search/action.py b/graphrag/query/structured_search/drift_search/action.py new file mode 100644 index 0000000000..1e90a9a18d --- /dev/null +++ b/graphrag/query/structured_search/drift_search/action.py @@ -0,0 +1,238 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""DRIFT Search Query State.""" + +import json +import logging +from typing import Any + +from graphrag.query.llm.text_utils import num_tokens + +log = logging.getLogger(__name__) + + +class DriftAction: + """ + Represent an action containing a query, answer, score, and follow-up actions. + + This class encapsulates action strings produced by the LLM in a structured way. + """ + + def __init__( + self, + query: str, + answer: str | None = None, + follow_ups: list["DriftAction"] | None = None, + ): + """ + Initialize the DriftAction with a query, optional answer, and follow-up actions. + + Args: + query (str): The query for the action. + answer (Optional[str]): The answer to the query, if available. + follow_ups (Optional[list[DriftAction]]): A list of follow-up actions. + """ + self.query = query + self.answer: str | None = answer # Corresponds to an 'intermediate_answer' + self.score: float | None = None + self.follow_ups: list[DriftAction] = ( + follow_ups if follow_ups is not None else [] + ) + self.metadata: dict[str, Any] = {} + + @property + def is_complete(self) -> bool: + """Check if the action is complete (i.e., an answer is available).""" + return self.answer is not None + + async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None): + """ + Execute an asynchronous search using the search engine, and update the action with the results. + + If a scorer is provided, compute the score for the action. + + Args: + search_engine (Any): The search engine to execute the query. + global_query (str): The global query string. + scorer (Any, optional): Scorer to compute scores for the action. + + Returns + ------- + self : DriftAction + Updated action with search results. + """ + if self.is_complete: + log.warning("Action already complete. 
Skipping search.") + return self + + search_result = await search_engine.asearch( + drift_query=global_query, query=self.query + ) + + try: + response = json.loads(search_result.response) + except json.JSONDecodeError as e: + error_message = "Failed to parse search response" + log.exception("%s: %s", error_message, search_result.response) + raise ValueError(error_message) from e + + self.answer = response.pop("response", None) + self.score = response.pop("score", float("-inf")) + self.metadata.update({"context_data": search_result.context_data}) + + if self.answer is None: + log.warning("No answer found for query: %s", self.query) + generated_tokens = 0 + else: + generated_tokens = num_tokens(self.answer, search_engine.token_encoder) + self.metadata.update({ + "token_ct": search_result.prompt_tokens + generated_tokens + }) + + self.follow_ups = response.pop("follow_up_queries", []) + if not self.follow_ups: + log.warning("No follow-up actions found for response: %s", response) + + if scorer: + self.compute_score(scorer) + + return self + + def compute_score(self, scorer: Any): + """ + Compute the score for the action using the provided scorer. + + Args: + scorer (Any): The scorer to compute the score. + """ + score = scorer.compute_score(self.query, self.answer) + self.score = ( + score if score is not None else float("-inf") + ) # Default to -inf for sorting + + def serialize(self, include_follow_ups: bool = True) -> dict[str, Any]: + """ + Serialize the action to a dictionary. + + Args: + include_follow_ups (bool): Whether to include follow-up actions in the serialization. + + Returns + ------- + dict[str, Any] + Serialized action as a dictionary. + """ + data = { + "query": self.query, + "answer": self.answer, + "score": self.score, + "metadata": self.metadata, + } + if include_follow_ups: + data["follow_ups"] = [action.serialize() for action in self.follow_ups] + return data + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> "DriftAction": + """ + Deserialize the action from a dictionary. + + Args: + data (dict[str, Any]): Serialized action data. + + Returns + ------- + DriftAction + A deserialized instance of DriftAction. + """ + # Ensure 'query' exists in the data, raise a ValueError if missing + query = data.get("query") + if query is None: + error_message = "Missing 'query' key in serialized data" + raise ValueError(error_message) + + # Initialize the DriftAction + action = cls(query) + action.answer = data.get("answer") + action.score = data.get("score") + action.metadata = data.get("metadata", {}) + + action.follow_ups = [ + cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", []) + ] + return action + + @classmethod + def from_primer_response( + cls, query: str, response: str | dict[str, Any] | list[dict[str, Any]] + ) -> "DriftAction": + """ + Create a DriftAction from a DRIFTPrimer response. + + Args: + query (str): The query string. + response (str | dict[str, Any] | list[dict[str, Any]]): Primer response data. + + Returns + ------- + DriftAction + A new instance of DriftAction based on the response. + + Raises + ------ + ValueError + If the response is not a dictionary or expected format. 
+ """ + if isinstance(response, dict): + action = cls( + query, + follow_ups=response.get("follow_up_queries", []), + answer=response.get("intermediate_answer"), + ) + action.score = response.get("score") + return action + + # If response is a string, attempt to parse as JSON + if isinstance(response, str): + try: + parsed_response = json.loads(response) + if isinstance(parsed_response, dict): + return cls.from_primer_response(query, parsed_response) + error_message = "Parsed response must be a dictionary." + raise ValueError(error_message) + except json.JSONDecodeError as e: + error_message = f"Failed to parse response string: {e}. Parsed response must be a dictionary." + raise ValueError(error_message) from e + + error_message = f"Unsupported response type: {type(response).__name__}. Expected a dictionary or JSON string." + raise ValueError(error_message) + + def __hash__(self) -> int: + """ + Allow DriftAction objects to be hashable for use in networkx.MultiDiGraph. + + Assumes queries are unique. + + Returns + ------- + int + Hash based on the query. + """ + return hash(self.query) + + def __eq__(self, other: object) -> bool: + """ + Check equality based on the query string. + + Args: + other (Any): Another object to compare with. + + Returns + ------- + bool + True if the other object is a DriftAction with the same query, False otherwise. + """ + if not isinstance(other, DriftAction): + return False + return self.query == other.query diff --git a/graphrag/query/structured_search/drift_search/drift_context.py b/graphrag/query/structured_search/drift_search/drift_context.py new file mode 100644 index 0000000000..7911c54c78 --- /dev/null +++ b/graphrag/query/structured_search/drift_search/drift_context.py @@ -0,0 +1,215 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""DRIFT Context Builder implementation.""" + +import logging +from dataclasses import asdict +from typing import Any + +import numpy as np +import pandas as pd +import tiktoken + +from graphrag.config.models.drift_config import DRIFTSearchConfig +from graphrag.model import ( + CommunityReport, + Covariate, + Entity, + Relationship, + TextUnit, +) +from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey +from graphrag.query.llm.base import BaseTextEmbedding +from graphrag.query.llm.oai.chat_openai import ChatOpenAI +from graphrag.query.structured_search.base import DRIFTContextBuilder +from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor +from graphrag.query.structured_search.drift_search.system_prompt import ( + DRIFT_LOCAL_SYSTEM_PROMPT, +) +from graphrag.query.structured_search.local_search.mixed_context import ( + LocalSearchMixedContext, +) +from graphrag.vector_stores import BaseVectorStore + +log = logging.getLogger(__name__) + + +class DRIFTSearchContextBuilder(DRIFTContextBuilder): + """Class representing the DRIFT Search Context Builder.""" + + def __init__( + self, + chat_llm: ChatOpenAI, + text_embedder: BaseTextEmbedding, + entities: list[Entity], + entity_text_embeddings: BaseVectorStore, + text_units: list[TextUnit] | None = None, + reports: list[CommunityReport] | None = None, + relationships: list[Relationship] | None = None, + covariates: dict[str, list[Covariate]] | None = None, + token_encoder: tiktoken.Encoding | None = None, + embedding_vectorstore_key: str = EntityVectorStoreKey.ID, + config: DRIFTSearchConfig | None = None, + local_system_prompt: str = DRIFT_LOCAL_SYSTEM_PROMPT, + local_mixed_context: LocalSearchMixedContext | None = None, + ): + """Initialize the DRIFT search context builder with necessary components.""" + self.config = config or DRIFTSearchConfig() + self.chat_llm = chat_llm + self.text_embedder = text_embedder + self.token_encoder = token_encoder + self.local_system_prompt = local_system_prompt + + self.entities = entities + self.entity_text_embeddings = entity_text_embeddings + self.reports = reports + self.text_units = text_units + self.relationships = relationships + self.covariates = covariates + self.embedding_vectorstore_key = embedding_vectorstore_key + + self.llm_tokens = 0 + self.local_mixed_context = ( + local_mixed_context or self.init_local_context_builder() + ) + + def init_local_context_builder(self) -> LocalSearchMixedContext: + """ + Initialize the local search mixed context builder. + + Returns + ------- + LocalSearchMixedContext: Initialized local context. + """ + return LocalSearchMixedContext( + community_reports=self.reports, + text_units=self.text_units, + entities=self.entities, + relationships=self.relationships, + covariates=self.covariates, + entity_text_embeddings=self.entity_text_embeddings, + embedding_vectorstore_key=self.embedding_vectorstore_key, + text_embedder=self.text_embedder, + token_encoder=self.token_encoder, + ) + + @staticmethod + def convert_reports_to_df(reports: list[CommunityReport]) -> pd.DataFrame: + """ + Convert a list of CommunityReport objects to a pandas DataFrame. + + Args + ---- + reports : list[CommunityReport] + List of CommunityReport objects. + + Returns + ------- + pd.DataFrame: DataFrame with report data. + + Raises + ------ + ValueError: If some reports are missing full content or full content embeddings. 
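+
+        Note
+        ----
+        Every report must carry ``full_content`` and ``full_content_embedding``.
+        When loading from indexer outputs, passing
+        ``content_embedding_col="full_content_embedding"`` to
+        ``read_indexer_reports`` should populate the embedding column.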
+ """ + report_df = pd.DataFrame([asdict(report) for report in reports]) + missing_content_error = "Some reports are missing full content." + missing_embedding_error = "Some reports are missing full content embeddings." + + if ( + "full_content" not in report_df.columns + or report_df["full_content"].isna().sum() > 0 + ): + raise ValueError(missing_content_error) + + if ( + "full_content_embedding" not in report_df.columns + or report_df["full_content_embedding"].isna().sum() > 0 + ): + raise ValueError(missing_embedding_error) + return report_df + + @staticmethod + def check_query_doc_encodings(query_embedding: Any, embedding: Any) -> bool: + """ + Check if the embeddings are compatible. + + Args + ---- + query_embedding : Any + Embedding of the query. + embedding : Any + Embedding to compare against. + + Returns + ------- + bool: True if embeddings match, otherwise False. + """ + return ( + query_embedding is not None + and embedding is not None + and isinstance(query_embedding, type(embedding)) + and len(query_embedding) == len(embedding) + and isinstance(query_embedding[0], type(embedding[0])) + ) + + def build_context(self, query: str, **kwargs) -> pd.DataFrame: + """ + Build DRIFT search context. + + Args + ---- + query : str + Search query string. + + Returns + ------- + pd.DataFrame: Top-k most similar documents. + + Raises + ------ + ValueError: If no community reports are available, or embeddings + are incompatible. + """ + if self.reports is None: + missing_reports_error = ( + "No community reports available. Please provide a list of reports." + ) + raise ValueError(missing_reports_error) + + query_processor = PrimerQueryProcessor( + chat_llm=self.chat_llm, + text_embedder=self.text_embedder, + token_encoder=self.token_encoder, + reports=self.reports, + ) + + query_embedding, token_ct = query_processor(query) + self.llm_tokens += token_ct + + report_df = self.convert_reports_to_df(self.reports) + + # Check compatibility between query embedding and document embeddings + if not self.check_query_doc_encodings( + query_embedding, report_df["full_content_embedding"].iloc[0] + ): + error_message = ( + "Query and document embeddings are not compatible. " + "Please ensure that the embeddings are of the same type and length." + ) + raise ValueError(error_message) + + # Vectorized cosine similarity computation + query_norm = np.linalg.norm(query_embedding) + document_norms = np.linalg.norm( + report_df["full_content_embedding"].to_list(), axis=1 + ) + dot_products = np.dot( + np.vstack(report_df["full_content_embedding"].to_list()), query_embedding + ) + report_df["similarity"] = dot_products / (document_norms * query_norm) + + # Sort by similarity and select top-k + top_k = report_df.nlargest(self.config.drift_k_followups, "similarity") + + return top_k.loc[:, ["short_id", "community_id", "full_content"]] diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py new file mode 100644 index 0000000000..1a3d7b27df --- /dev/null +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Primer for DRIFT search."""
+
+import json
+import logging
+import secrets
+import time
+
+import numpy as np
+import pandas as pd
+import tiktoken
+from tqdm.asyncio import tqdm_asyncio
+
+from graphrag.config.models.drift_config import DRIFTSearchConfig
+from graphrag.model import CommunityReport
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.query.llm.oai.chat_openai import ChatOpenAI
+from graphrag.query.llm.text_utils import num_tokens
+from graphrag.query.structured_search.base import SearchResult
+from graphrag.query.structured_search.drift_search.system_prompt import (
+    DRIFT_PRIMER_PROMPT,
+)
+
+log = logging.getLogger(__name__)
+
+
+class PrimerQueryProcessor:
+    """Expand the query into a hypothetical answer using a community report as a template, then embed the result."""
+
+    def __init__(
+        self,
+        chat_llm: ChatOpenAI,
+        text_embedder: BaseTextEmbedding,
+        reports: list[CommunityReport],
+        token_encoder: tiktoken.Encoding | None = None,
+    ):
+        """
+        Initialize the PrimerQueryProcessor.
+
+        Args:
+            chat_llm (ChatOpenAI): The language model used to process the query.
+            text_embedder (BaseTextEmbedding): The text embedding model.
+            reports (list[CommunityReport]): List of community reports.
+            token_encoder (tiktoken.Encoding, optional): Token encoder for token counting.
+        """
+        self.chat_llm = chat_llm
+        self.text_embedder = text_embedder
+        self.token_encoder = token_encoder
+        self.reports = reports
+
+    def expand_query(self, query: str) -> tuple[str, int]:
+        """
+        Expand the query using a random community report template.
+
+        Args:
+            query (str): The original search query.
+
+        Returns
+        -------
+        tuple[str, int]: Expanded query text and the number of tokens used.
+        """
+        template = secrets.choice(self.reports).full_content  # nosec S311
+
+        prompt = f"""Create a hypothetical answer to the following query: {query}\n\n
+        Format it to follow the structure of the template below:\n\n
+        {template}\n
+        Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
+
+        messages = [{"role": "user", "content": prompt}]
+
+        text = self.chat_llm.generate(messages)
+        token_ct = num_tokens(text + query, self.token_encoder)
+        if text == "":
+            log.warning("Failed to generate expansion for query: %s", query)
+            return query, token_ct
+        return text, token_ct
+
+    def __call__(self, query: str) -> tuple[list[float], int]:
+        """
+        Call method to process the query, expand it, and embed the result.
+
+        Args:
+            query (str): The search query.
+
+        Returns
+        -------
+        tuple[list[float], int]: List of embeddings for the expanded query and the token count.
+        """
+        hyde_query, token_ct = self.expand_query(query)
+        log.info("Expanded query: %s", hyde_query)
+        return self.text_embedder.embed(hyde_query), token_ct
+
+
+class DRIFTPrimer:
+    """Perform initial query decomposition using global guidance from information in community reports."""
+
+    def __init__(
+        self,
+        config: DRIFTSearchConfig,
+        chat_llm: ChatOpenAI,
+        token_encoder: tiktoken.Encoding | None = None,
+    ):
+        """
+        Initialize the DRIFTPrimer.
+
+        Args:
+            config (DRIFTSearchConfig): Configuration settings for DRIFT search.
+            chat_llm (ChatOpenAI): The language model used for searching.
+            token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
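+
+        Note
+        ----
+        Rather than issuing a single LLM call over all reports, ``asearch``
+        splits the top-K reports into ``config.primer_folds`` folds and
+        decomposes the query against each fold in parallel.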
+ """ + self.llm = chat_llm + self.config = config + self.token_encoder = token_encoder + + async def decompose_query( + self, query: str, reports: pd.DataFrame + ) -> tuple[dict, int]: + """ + Decompose the query into subqueries based on the fetched global structures. + + Args: + query (str): The original search query. + reports (pd.DataFrame): DataFrame containing community reports. + + Returns + ------- + tuple[dict, int]: Parsed response and the number of tokens used. + """ + community_reports = "\n\n".join(reports["full_content"].tolist()) + prompt = DRIFT_PRIMER_PROMPT.format( + query=query, community_reports=community_reports + ) + messages = [{"role": "user", "content": prompt}] + + response = await self.llm.agenerate( + messages, response_format={"type": "json_object"} + ) + + parsed_response = json.loads(response) + token_ct = num_tokens(prompt + response, self.token_encoder) + + return parsed_response, token_ct + + async def asearch( + self, + query: str, + top_k_reports: pd.DataFrame, + ) -> SearchResult: + """ + Asynchronous search method that processes the query and returns a SearchResult. + + Args: + query (str): The search query. + top_k_reports (pd.DataFrame): DataFrame containing the top-k reports. + + Returns + ------- + SearchResult: The search result containing the response and context data. + """ + start_time = time.perf_counter() + report_folds = self.split_reports(top_k_reports) + tasks = [self.decompose_query(query, fold) for fold in report_folds] + results_with_tokens = await tqdm_asyncio.gather(*tasks) + + completion_time = time.perf_counter() - start_time + + return SearchResult( + response=[response for response, _ in results_with_tokens], + context_data={"top_k_reports": top_k_reports}, + context_text=top_k_reports.to_json() or "", + completion_time=completion_time, + llm_calls=len(results_with_tokens), + prompt_tokens=sum(tokens for _, tokens in results_with_tokens), + ) + + def split_reports(self, reports: pd.DataFrame) -> list[pd.DataFrame]: + """ + Split the reports into folds, allowing for parallel processing. + + Args: + reports (pd.DataFrame): DataFrame of community reports. + + Returns + ------- + list[pd.DataFrame]: List of report folds. + """ + primer_folds = self.config.primer_folds or 1 # Ensure at least one fold + if primer_folds == 1: + return [reports] + return [pd.DataFrame(fold) for fold in np.array_split(reports, primer_folds)] diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py new file mode 100644 index 0000000000..947cb726e6 --- /dev/null +++ b/graphrag/query/structured_search/drift_search/search.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""DRIFT Search implementation.""" + +import logging +import time +from collections.abc import AsyncGenerator +from typing import Any + +import tiktoken +from tqdm.asyncio import tqdm_asyncio + +from graphrag.config.models.drift_config import DRIFTSearchConfig +from graphrag.query.context_builder.conversation_history import ConversationHistory +from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey +from graphrag.query.llm.oai.chat_openai import ChatOpenAI +from graphrag.query.structured_search.base import BaseSearch, SearchResult +from graphrag.query.structured_search.drift_search.action import DriftAction +from graphrag.query.structured_search.drift_search.drift_context import ( + DRIFTSearchContextBuilder, +) +from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer +from graphrag.query.structured_search.drift_search.state import QueryState +from graphrag.query.structured_search.local_search.search import LocalSearch + +log = logging.getLogger(__name__) + + +class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]): + """Class representing a DRIFT Search.""" + + def __init__( + self, + llm: ChatOpenAI, + context_builder: DRIFTSearchContextBuilder, + config: DRIFTSearchConfig | None = None, + token_encoder: tiktoken.Encoding | None = None, + query_state: QueryState | None = None, + ): + """ + Initialize the DRIFTSearch class. + + Args: + llm (ChatOpenAI): The language model used for searching. + context_builder (DRIFTSearchContextBuilder): Builder for search context. + config (DRIFTSearchConfig, optional): Configuration settings for DRIFTSearch. + token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens. + query_state (QueryState, optional): State of the current search query. + """ + super().__init__(llm, context_builder, token_encoder) + + self.config = config or DRIFTSearchConfig() + self.context_builder = context_builder + self.token_encoder = token_encoder + self.query_state = query_state or QueryState() + self.primer = DRIFTPrimer( + config=self.config, chat_llm=llm, token_encoder=token_encoder + ) + self.local_search = self.init_local_search() + + def init_local_search(self) -> LocalSearch: + """ + Initialize the LocalSearch object with parameters based on the DRIFT search configuration. + + Returns + ------- + LocalSearch: An instance of the LocalSearch class with the configured parameters. 
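+
+        The engine is configured with ``response_format={"type": "json_object"}``
+        and the DRIFT local system prompt, so each follow-up search returns JSON
+        with ``response``, ``score``, and ``follow_up_queries`` keys that
+        ``DriftAction.asearch`` can parse.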
+        """
+        local_context_params = {
+            "text_unit_prop": self.config.local_search_text_unit_prop,
+            "community_prop": self.config.local_search_community_prop,
+            "top_k_mapped_entities": self.config.local_search_top_k_mapped_entities,
+            "top_k_relationships": self.config.local_search_top_k_relationships,
+            "include_entity_rank": True,
+            "include_relationship_weight": True,
+            "include_community_rank": False,
+            "return_candidate_context": False,
+            "embedding_vectorstore_key": EntityVectorStoreKey.ID,
+            "max_tokens": self.config.local_search_max_data_tokens,
+        }
+
+        llm_params = {
+            "max_tokens": self.config.local_search_llm_max_gen_tokens,
+            "temperature": self.config.local_search_temperature,
+            "response_format": {"type": "json_object"},
+        }
+
+        return LocalSearch(
+            llm=self.llm,
+            system_prompt=self.context_builder.local_system_prompt,
+            context_builder=self.context_builder.local_mixed_context,
+            token_encoder=self.token_encoder,
+            llm_params=llm_params,
+            context_builder_params=local_context_params,
+            response_type="multiple paragraphs",
+        )
+
+    def _process_primer_results(
+        self, query: str, search_results: SearchResult
+    ) -> DriftAction:
+        """
+        Process the results from the primer search to extract intermediate answers and follow-up queries.
+
+        Args:
+            query (str): The original search query.
+            search_results (SearchResult): The results from the primer search.
+
+        Returns
+        -------
+        DriftAction: Action generated from the primer response.
+
+        Raises
+        ------
+        RuntimeError: If no intermediate answers or follow-up queries are found in the primer response.
+        """
+        response = search_results.response
+        if isinstance(response, list) and response and isinstance(response[0], dict):
+            intermediate_answers = [
+                i["intermediate_answer"] for i in response if "intermediate_answer" in i
+            ]
+
+            if not intermediate_answers:
+                error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
+                raise RuntimeError(error_msg)
+
+            intermediate_answer = "\n\n".join(intermediate_answers)
+
+            follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
+            if len(follow_ups) == 0:
+                error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries."
+                raise RuntimeError(error_msg)
+
+            score = sum(i["score"] for i in response) / len(response)
+            response_data = {
+                "intermediate_answer": intermediate_answer,
+                "follow_up_queries": follow_ups,
+                "score": score,
+            }
+            return DriftAction.from_primer_response(query, response_data)
+        error_msg = "Response must be a non-empty list of dictionaries."
+        raise ValueError(error_msg)
+
+    async def asearch_step(
+        self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
+    ) -> list[DriftAction]:
+        """
+        Perform an asynchronous search step by executing each DriftAction asynchronously.
+
+        Args:
+            global_query (str): The global query for the search.
+            search_engine (LocalSearch): The local search engine instance.
+            actions (list[DriftAction]): A list of actions to perform.
+
+        Returns
+        -------
+        list[DriftAction]: The results from executing the search actions asynchronously.
+        """
+        tasks = [
+            action.asearch(search_engine=search_engine, global_query=global_query)
+            for action in actions
+        ]
+        return await tqdm_asyncio.gather(*tasks)
+
+    async def asearch(
+        self,
+        query: str,
+        conversation_history: Any = None,
+        **kwargs,
+    ) -> SearchResult:
+        """
+        Perform an asynchronous DRIFT search.
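+
+        On the first call, the primer builds a HyDE-style context and seeds the
+        action graph with follow-up queries; each subsequent round (up to
+        ``config.n_depth``) executes the top-ranked incomplete actions against
+        local search and folds their follow-ups back into the graph.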
+
+        Args:
+            query (str): The query to search for.
+            conversation_history (Any, optional): The conversation history, if any.
+
+        Returns
+        -------
+        SearchResult: The search result containing the response and context data.
+
+        Raises
+        ------
+        ValueError: If the query is empty.
+        """
+        if query == "":
+            error_msg = "DRIFT Search query cannot be empty."
+            raise ValueError(error_msg)
+
+        start_time = time.perf_counter()
+        primer_token_ct = 0
+        context_token_ct = 0
+
+        # Check if query state is empty
+        if not self.query_state.graph:
+            # Prime the search with the primer
+            primer_context = self.context_builder.build_context(query)
+            context_token_ct = self.context_builder.llm_tokens
+
+            primer_response = await self.primer.asearch(
+                query=query, top_k_reports=primer_context
+            )
+            primer_token_ct = primer_response.prompt_tokens
+            # Package response into DriftAction
+            init_action = self._process_primer_results(query, primer_response)
+            self.query_state.add_action(init_action)
+            self.query_state.add_all_follow_ups(init_action, init_action.follow_ups)
+
+        # Main loop: run up to n_depth rounds of follow-up searches
+        epochs = 0
+        llm_call_offset = 0
+        while epochs < self.config.n_depth:
+            actions = self.query_state.rank_incomplete_actions()
+            if len(actions) == 0:
+                log.info("No more actions to take. Exiting DRIFT loop.")
+                break
+            actions = actions[: self.config.drift_k_followups]
+            llm_call_offset += len(actions) - self.config.drift_k_followups
+            # Process actions
+            results = await self.asearch_step(
+                global_query=query, search_engine=self.local_search, actions=actions
+            )
+
+            # Update query state
+            for action in results:
+                self.query_state.add_action(action)
+                self.query_state.add_all_follow_ups(action, action.follow_ups)
+            epochs += 1
+
+        t_elapsed = time.perf_counter() - start_time
+
+        # Calculate token usage
+        total_tokens = (
+            primer_token_ct + context_token_ct + self.query_state.action_token_ct()
+        )
+
+        # Package up context data
+        response_state, context_data, context_text = self.query_state.serialize(
+            include_context=True
+        )
+
+        return SearchResult(
+            response=response_state,
+            context_data=context_data,
+            context_text=context_text,
+            completion_time=t_elapsed,
+            llm_calls=1
+            + self.config.primer_folds
+            + (self.config.drift_k_followups - llm_call_offset) * self.config.n_depth,
+            prompt_tokens=total_tokens,
+        )
+
+    def search(
+        self,
+        query: str,
+        conversation_history: Any = None,
+        **kwargs,
+    ) -> SearchResult:
+        """
+        Perform a synchronous DRIFT search (Not Implemented).
+
+        Args:
+            query (str): The query to search for.
+            conversation_history (Any, optional): The conversation history.
+
+        Raises
+        ------
+        NotImplementedError: Synchronous DRIFT is not implemented.
+        """
+        error_msg = "Synchronous DRIFT is not implemented."
+        raise NotImplementedError(error_msg)
+
+    def astream_search(
+        self, query: str, conversation_history: ConversationHistory | None = None
+    ) -> AsyncGenerator[str, None]:
+        """
+        Perform a streaming DRIFT search (Not Implemented).
+
+        Args:
+            query (str): The query to search for.
+            conversation_history (ConversationHistory, optional): The conversation history.
+
+        Raises
+        ------
+        NotImplementedError: Streaming DRIFT search is not implemented.
+        """
+        error_msg = "Streaming DRIFT search is not implemented."
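+        # NOTE: a streaming variant could yield partial serializations of the
+        # query state as each round of follow-up actions completes.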
+        raise NotImplementedError(error_msg)
diff --git a/graphrag/query/structured_search/drift_search/state.py b/graphrag/query/structured_search/drift_search/state.py
new file mode 100644
index 0000000000..979a03e212
--- /dev/null
+++ b/graphrag/query/structured_search/drift_search/state.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Manage the state of the DRIFT query, including a graph of actions."""
+
+import logging
+import random
+from collections.abc import Callable
+from typing import Any
+
+import networkx as nx
+
+from graphrag.query.structured_search.drift_search.action import DriftAction
+
+log = logging.getLogger(__name__)
+
+
+class QueryState:
+    """Manage the state of the query, including a graph of actions."""
+
+    def __init__(self):
+        self.graph = nx.MultiDiGraph()
+
+    def add_action(self, action: DriftAction, metadata: dict[str, Any] | None = None):
+        """Add an action to the graph with optional metadata."""
+        self.graph.add_node(action, **(metadata or {}))
+
+    def relate_actions(
+        self, parent: DriftAction, child: DriftAction, weight: float = 1.0
+    ):
+        """Relate two actions in the graph."""
+        self.graph.add_edge(parent, child, weight=weight)
+
+    def add_all_follow_ups(
+        self,
+        action: DriftAction,
+        follow_ups: list[DriftAction] | list[str],
+        weight: float = 1.0,
+    ):
+        """Add all follow-up actions and link them to the given action."""
+        if len(follow_ups) == 0:
+            log.warning("No follow-up actions for action: %s", action.query)
+
+        for follow_up in follow_ups:
+            if isinstance(follow_up, str):
+                follow_up = DriftAction(query=follow_up)
+            elif not isinstance(follow_up, DriftAction):
+                log.warning(
+                    "Follow-up is neither a string nor a DriftAction, found type: %s",
+                    type(follow_up),
+                )
+
+            self.add_action(follow_up)
+            self.relate_actions(action, follow_up, weight)
+
+    def find_incomplete_actions(self) -> list[DriftAction]:
+        """Find all unanswered actions in the graph."""
+        return [node for node in self.graph.nodes if not node.is_complete]
+
+    def rank_incomplete_actions(
+        self, scorer: Callable[[DriftAction], float] | None = None
+    ) -> list[DriftAction]:
+        """Rank all unanswered actions in the graph if a scorer is available."""
+        unanswered = self.find_incomplete_actions()
+        if scorer:
+            for node in unanswered:
+                node.compute_score(scorer)
+            return sorted(
+                unanswered,
+                key=lambda node: (
+                    node.score if node.score is not None else float("-inf")
+                ),
+                reverse=True,
+            )
+
+        # shuffle the list if no scorer
+        random.shuffle(unanswered)
+        return list(unanswered)
+
+    def serialize(
+        self, include_context: bool = True
+    ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any], str]:
+        """Serialize the graph to a dictionary, including nodes and edges."""
+        # Create a mapping from nodes to unique IDs
+        node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}
+
+        # Serialize nodes
+        nodes: list[dict[str, Any]] = [
+            {
+                **node.serialize(include_follow_ups=False),
+                "id": node_to_id[node],
+                **self.graph.nodes[node],
+            }
+            for node in self.graph.nodes()
+        ]
+
+        # Serialize edges
+        edges: list[dict[str, Any]] = [
+            {
+                "source": node_to_id[u],
+                "target": node_to_id[v],
+                "weight": edge_data.get("weight", 1.0),
+            }
+            for u, v, edge_data in self.graph.edges(data=True)
+        ]
+
+        if include_context:
+            context_data = {
+                node["query"]: node["metadata"]["context_data"]
+                for node in nodes
+                if node["metadata"].get("context_data") and node.get("query")
+            }
+
+            context_text = str(context_data)
+
+            return {"nodes": nodes, "edges":
edges}, context_data, context_text
+
+        return {"nodes": nodes, "edges": edges}
+
+    def deserialize(self, data: dict[str, Any]):
+        """Deserialize the dictionary back to a graph."""
+        self.graph.clear()
+        id_to_action = {}
+
+        for node_data in data.get("nodes", []):
+            node_id = node_data.pop("id")
+            action = DriftAction.deserialize(node_data)
+            self.add_action(action)
+            id_to_action[node_id] = action
+
+        for edge_data in data.get("edges", []):
+            source_id = edge_data["source"]
+            target_id = edge_data["target"]
+            weight = edge_data.get("weight", 1.0)
+            source_action = id_to_action.get(source_id)
+            target_action = id_to_action.get(target_id)
+            if source_action and target_action:
+                self.relate_actions(source_action, target_action, weight)
+
+    def action_token_ct(self) -> int:
+        """Return the total token count across all actions in the graph."""
+        return sum(action.metadata.get("token_ct", 0) for action in self.graph.nodes)
diff --git a/graphrag/query/structured_search/drift_search/system_prompt.py b/graphrag/query/structured_search/drift_search/system_prompt.py
new file mode 100644
index 0000000000..eb0e07c262
--- /dev/null
+++ b/graphrag/query/structured_search/drift_search/system_prompt.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""DRIFT Search prompts."""
+
+DRIFT_LOCAL_SYSTEM_PROMPT = """
+---Role---
+
+You are a helpful assistant responding to questions about data in the tables provided.
+
+
+---Goal---
+
+Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
+
+If you don't know the answer, just say so. Do not make anything up.
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]."
+
+where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Pay close attention specifically to the Sources table, as it contains the most relevant information for the user query. You will be rewarded for preserving the context of the sources in your response.
+
+---Target response length and format---
+
+{response_type}
+
+
+---Data tables---
+
+{context_data}
+
+
+---Goal---
+
+Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
+
+If you don't know the answer, just say so. Do not make anything up.
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]."
+
+where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Pay close attention specifically to the Sources table, as it contains the most relevant information for the user query. You will be rewarded for preserving the context of the sources in your response.
+
+---Target response length and format---
+
+{response_type}
+
+Add sections and commentary to the response as appropriate for the length and format.
+
+Additionally provide a score for how well the response addresses the overall research question: {global_query}. Based on your response, suggest a few follow-up questions that could be asked to further explore the topic. Do not include scores or follow-up questions in the 'response' field of the JSON; add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Generate at least five good follow-up queries. Format your response in JSON with the following keys and values:
+
+{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
+'score': int,
+'follow_up_queries': List[str]}}
+"""
+
+
+DRIFT_REDUCE_PROMPT = """
+---Role---
+
+You are a helpful assistant responding to questions about data in the reports provided.
+
+---Goal---
+
+Generate a response of the target length and format that responds to the user's question, summarizing all information in the input reports appropriate for the response length and format, and incorporating any relevant general knowledge while being as specific, accurate and concise as possible.
+
+If you don't know the answer, just say so. Do not make anything up.
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (1, 5, 15)]."
+
+Do not include information where the supporting evidence for it is not provided.
+
+If you decide to use general knowledge, you should add a delimiter stating that the information is not supported by the data tables. For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing. [Data: General Knowledge (href)]"
+
+---Data Reports---
+
+{context_data}
+
+---Target response length and format---
+
+Multiple paragraphs
+
+
+---Goal---
+
+Generate a response of the target length and format that responds to the user's question, summarizing all information in the input reports appropriate for the response length and format, and incorporating any relevant general knowledge while being as specific, accurate and concise as possible.
+
+If you don't know the answer, just say so. Do not make anything up.
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (1, 5, 15)]."
+
+Do not include information where the supporting evidence for it is not provided.
+
+If you decide to use general knowledge, you should add a delimiter stating that the information is not supported by the data tables. For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing. [Data: General Knowledge (href)]".
+
+Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
+
+{query}
+
+"""
+
+
+DRIFT_PRIMER_PROMPT = """You are a helpful agent designed to reason over a knowledge graph in response to a user query.
+This is a unique knowledge graph where edges are freeform text rather than verb operators. You will begin your reasoning by looking at a summary of the content of the most relevant communities and will provide:
+
+1. score: How well the intermediate answer addresses the query. A score of 0 indicates a poor, unfocused answer, while a score of 100 indicates a highly focused, relevant answer that addresses the query in its entirety.
+
+2. intermediate_answer: This answer should match the level of detail and length found in the community summaries. The intermediate answer should be exactly 2000 characters long. This must be formatted in markdown and must begin with a header that explains how the following text is related to the query.
+
+3. follow_up_queries: A list of follow-up queries that could be asked to further explore the topic. These should be formatted as a list of strings. Generate at least five good follow-up queries.
+
+Use this information to help you decide whether or not you need more information about the entities mentioned in the report. You may also use your general knowledge to think of entities which may help enrich your answer.
+
+You will also provide a full answer from the content you have available. Use the data provided to generate follow-up queries to help refine your search. Do not ask compound questions, for example: "What is the market cap of Apple and Microsoft?". Use your knowledge of the entity distribution to focus on entity types that will be useful for searching a broad area of the knowledge graph.
+ +For the query: + +{query} + +The top-ranked community summaries: + +{community_reports} + +Provide the intermediate answer, and all scores in JSON format following: + +{{'intermediate_answer': str, +'score': int, +'follow_up_queries': List[str]}} + +Begin: +""" diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index 5e8a71b937..5945ab9e98 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -54,7 +54,7 @@ class GlobalSearchResult(SearchResult): reduce_context_text: str | list[str] | dict[str, str] -class GlobalSearch(BaseSearch): +class GlobalSearch(BaseSearch[GlobalContextBuilder]): """Search orchestration for global search mode.""" def __init__( @@ -145,6 +145,7 @@ async def asearch( - Step 2: Combine the answers from step 2 to generate the final answer """ # Step 1: Generate answers for each batch of community short summaries + start_time = time.time() context_chunks, context_records = self.context_builder.build_context( conversation_history=conversation_history, **self.context_builder_params diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index fd94e46f8d..412e795c09 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -29,7 +29,7 @@ log = logging.getLogger(__name__) -class LocalSearch(BaseSearch): +class LocalSearch(BaseSearch[LocalContextBuilder]): """Search orchestration for local search mode.""" def __init__( @@ -72,9 +72,17 @@ async def asearch( ) log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query) try: - search_prompt = self.system_prompt.format( - context_data=context_text, response_type=self.response_type - ) + if "drift_query" in kwargs: + drift_query = kwargs["drift_query"] + search_prompt = self.system_prompt.format( + context_data=context_text, + response_type=self.response_type, + global_query=drift_query, + ) + else: + search_prompt = self.system_prompt.format( + context_data=context_text, response_type=self.response_type + ) search_messages = [ {"role": "system", "content": search_prompt}, {"role": "user", "content": query}, diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 53986d9914..db07f7730d 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -5,23 +5,23 @@ "create_base_text_units": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 100 }, "create_base_entity_graph": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, - "max_runtime": 300 + "max_runtime": 100 }, "create_final_entities": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "type", @@ -34,7 +34,7 @@ "create_final_relationships": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, "max_runtime": 100 @@ -42,7 +42,7 @@ "create_final_nodes": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "entity_type", @@ -52,20 +52,20 @@ "level" ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 100 }, "create_final_communities": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 100 }, "create_final_community_reports": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "community_id", @@ -83,7 +83,7 @@ "create_final_text_units": { "row_range": [ 1, - 2000 + 2500 ], 
"nan_allowed_columns": [ "relationship_ids", @@ -95,7 +95,7 @@ "create_final_documents": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, "max_runtime": 100 diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index af1eba3ff3..6e8a2a3ebe 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -5,15 +5,15 @@ "create_base_text_units": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 100 }, "create_final_covariates": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "type", @@ -30,7 +30,7 @@ "create_base_entity_graph": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, "max_runtime": 300 @@ -38,7 +38,7 @@ "create_final_entities": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "type", @@ -51,7 +51,7 @@ "create_final_relationships": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, "max_runtime": 100 @@ -59,7 +59,7 @@ "create_final_nodes": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "entity_type", @@ -69,20 +69,20 @@ "level" ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 100 }, "create_final_communities": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 100 }, "create_final_community_reports": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "community_id", @@ -100,7 +100,7 @@ "create_final_text_units": { "row_range": [ 1, - 2000 + 2500 ], "nan_allowed_columns": [ "relationship_ids", @@ -112,7 +112,7 @@ "create_final_documents": { "row_range": [ 1, - 2000 + 2500 ], "subworkflows": 1, "max_runtime": 100