DRIFT Search (#1285)
* drift search

* args for drift global query in local search

* accept drift context in search base

* optionally parse embeddings from df when creating CommunityReport

* abstract class for drift context

* pathing for drift config

* drift config

* add defs for drift config

* formatting

* capture generated tokens in token count

* semversion

* Formatting and ruff

* Some algorithmic refactors

* Ruff

* Format

* Use asdict()

* Address comments

* Update smoke tests

* Update smoke tests

* Update smoke tests part 2

---------

Co-authored-by: Julian Whiting <[email protected]>
AlonsoGuevara and j2whiting authored Oct 21, 2024
1 parent e0840a2 commit 8a6d4e6
Showing 19 changed files with 1,448 additions and 39 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20241011205050985571.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Added DRIFT graph reasoning query module"
}
25 changes: 25 additions & 0 deletions graphrag/config/defaults.py
@@ -115,3 +115,28 @@
GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000
GLOBAL_SEARCH_REDUCE_MAX_TOKENS = 2_000
GLOBAL_SEARCH_CONCURRENCY = 32

# DRIFT Search

DRIFT_SEARCH_LLM_TEMPERATURE = 0
DRIFT_SEARCH_LLM_TOP_P = 1
DRIFT_SEARCH_LLM_N = 3
DRIFT_SEARCH_MAX_TOKENS = 12_000
DRIFT_SEARCH_DATA_MAX_TOKENS = 12_000
DRIFT_SEARCH_CONCURRENCY = 32

DRIFT_SEARCH_K_FOLLOW_UPS = 20
DRIFT_SEARCH_PRIMER_FOLDS = 5
DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000

DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9
DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1
DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10
DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS = 10
DRIFT_LOCAL_SEARCH_MAX_TOKENS = 12_000
DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE = 0
DRIFT_LOCAL_SEARCH_LLM_TOP_P = 1
DRIFT_LOCAL_SEARCH_LLM_N = 1
DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS = 2000

DRIFT_N_DEPTH = 3
2 changes: 2 additions & 0 deletions graphrag/config/models/__init__.py
@@ -8,6 +8,7 @@
from .claim_extraction_config import ClaimExtractionConfig
from .cluster_graph_config import ClusterGraphConfig
from .community_reports_config import CommunityReportsConfig
from .drift_config import DRIFTSearchConfig
from .embed_graph_config import EmbedGraphConfig
from .entity_extraction_config import EntityExtractionConfig
from .global_search_config import GlobalSearchConfig
@@ -30,6 +31,7 @@
"ClaimExtractionConfig",
"ClusterGraphConfig",
"CommunityReportsConfig",
"DRIFTSearchConfig",
"EmbedGraphConfig",
"EntityExtractionConfig",
"GlobalSearchConfig",
103 changes: 103 additions & 0 deletions graphrag/config/models/drift_config.py
@@ -0,0 +1,103 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

from pydantic import BaseModel, Field

import graphrag.config.defaults as defs


class DRIFTSearchConfig(BaseModel):
"""The default configuration section for Cache."""

temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=defs.DRIFT_SEARCH_LLM_TOP_P,
)
n: int = Field(
description="The number of completions to generate.",
default=defs.DRIFT_SEARCH_LLM_N,
)
max_tokens: int = Field(
description="The maximum context size in tokens.",
default=defs.DRIFT_SEARCH_MAX_TOKENS,
)
data_max_tokens: int = Field(
description="The data llm maximum tokens.",
default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
)

concurrency: int = Field(
description="The number of concurrent requests.",
default=defs.DRIFT_SEARCH_CONCURRENCY,
)

drift_k_followups: int = Field(
description="The number of top global results to retrieve.",
default=defs.DRIFT_SEARCH_K_FOLLOW_UPS,
)

primer_folds: int = Field(
description="The number of folds for search priming.",
default=defs.DRIFT_SEARCH_PRIMER_FOLDS,
)

primer_llm_max_tokens: int = Field(
description="The maximum number of tokens for the LLM in primer.",
default=defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
)

n_depth: int = Field(
description="The number of drift search steps to take.",
default=defs.DRIFT_N_DEPTH,
)

local_search_text_unit_prop: float = Field(
description="The proportion of search dedicated to text units.",
default=defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
)

local_search_community_prop: float = Field(
description="The proportion of search dedicated to community properties.",
default=defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
)

local_search_top_k_mapped_entities: int = Field(
description="The number of top K entities to map during local search.",
default=defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
)

local_search_top_k_relationships: int = Field(
description="The number of top K relationships to map during local search.",
default=defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
)

local_search_max_data_tokens: int = Field(
description="The maximum context size in tokens for local search.",
default=defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
)

local_search_temperature: float = Field(
description="The temperature to use for token generation in local search.",
default=defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
)

local_search_top_p: float = Field(
description="The top-p value to use for token generation in local search.",
default=defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
)

local_search_n: int = Field(
description="The number of completions to generate in local search.",
default=defs.DRIFT_LOCAL_SEARCH_LLM_N,
)

local_search_llm_max_gen_tokens: int = Field(
description="The maximum number of generated tokens for the LLM in local search.",
default=defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
)
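
For orientation, a minimal usage sketch that is not part of this diff: the model validates overrides against the field types above and otherwise falls back to the values in defaults.py. The model_dump() call assumes Pydantic v2, which matches the BaseModel/Field imports in this file.

from graphrag.config.models import DRIFTSearchConfig

# All fields fall back to the defaults declared in graphrag/config/defaults.py.
config = DRIFTSearchConfig()
assert config.primer_folds == 5  # DRIFT_SEARCH_PRIMER_FOLDS

# Individual fields can be overridden at construction time.
custom = DRIFTSearchConfig(n_depth=5, drift_k_followups=10)
print(custom.model_dump()["n_depth"])  # 5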
12 changes: 12 additions & 0 deletions graphrag/query/context_builder/builders.py
@@ -33,3 +33,15 @@ def build_context(
**kwargs,
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
"""Build the context for the local search mode."""


class DRIFTContextBuilder(ABC):
"""Base class for DRIFT-search context builders."""

@abstractmethod
def build_context(
self,
query: str,
**kwargs,
) -> pd.DataFrame:
"""Build the context for the primer search actions."""
5 changes: 2 additions & 3 deletions graphrag/query/factories.py
@@ -21,7 +21,6 @@
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.base import BaseSearch
from graphrag.query.structured_search.global_search.community_context import (
GlobalCommunityContext,
)
@@ -107,7 +106,7 @@ def get_local_search_engine(
covariates: dict[str, list[Covariate]],
response_type: str,
description_embedding_store: BaseVectorStore,
) -> BaseSearch:
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
text_embedder = get_text_embedder(config)
@@ -158,7 +157,7 @@ def get_global_search_engine(
reports: list[CommunityReport],
entities: list[Entity],
response_type: str,
) -> BaseSearch:
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
token_encoder = tiktoken.get_encoding(config.encoding_model)
gs_config = config.global_search
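
Why the return-type narrowing in this file matters, in miniature. This is a self-contained sketch, not graphrag code: with the broad annotation a type checker flags subclass-only members, while the concrete annotation lets callers use them without a cast.

class Base:
    def search(self, query: str) -> str:
        raise NotImplementedError


class Local(Base):
    def search(self, query: str) -> str:
        return f"local result for {query!r}"

    def map_entities(self, query: str) -> list[str]:
        # Subclass-only helper, invisible through a Base-typed reference.
        return [query]


def make_engine_broad() -> Base:
    return Local()


def make_engine_narrow() -> Local:
    return Local()


# make_engine_broad().map_entities("q")   # a type checker rejects this access
make_engine_narrow().map_entities("q")     # fine: the concrete type is known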
3 changes: 2 additions & 1 deletion graphrag/query/indexer_adapters.py
@@ -63,6 +63,7 @@ def read_indexer_reports(
final_community_reports: pd.DataFrame,
final_nodes: pd.DataFrame,
community_level: int,
content_embedding_col: str | None = None,
) -> list[CommunityReport]:
"""Read in the Community Reports from the raw indexing outputs."""
report_df = final_community_reports
@@ -83,7 +84,7 @@
id_col="community",
short_id_col="community",
summary_embedding_col=None,
content_embedding_col=None,
content_embedding_col=content_embedding_col,
)


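
A usage sketch for the new content_embedding_col parameter. The parquet paths and the column name below are illustrative, not from this commit; pass whatever embedding column your community-reports table actually carries.

import pandas as pd

from graphrag.query.indexer_adapters import read_indexer_reports

# Point these at your own indexing output.
final_community_reports = pd.read_parquet("output/create_final_community_reports.parquet")
final_nodes = pd.read_parquet("output/create_final_nodes.parquet")

reports = read_indexer_reports(
    final_community_reports,
    final_nodes,
    community_level=2,
    content_embedding_col="full_content_embedding",  # hypothetical column name
)
# Omitting the argument keeps the previous behavior: no content embedding is parsed.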
12 changes: 8 additions & 4 deletions graphrag/query/structured_search/base.py
@@ -6,12 +6,13 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from typing import Any, Generic, TypeVar

import pandas as pd
import tiktoken

from graphrag.query.context_builder.builders import (
DRIFTContextBuilder,
GlobalContextBuilder,
LocalContextBuilder,
)
@@ -34,13 +35,16 @@ class SearchResult:
prompt_tokens: int


class BaseSearch(ABC):
T = TypeVar("T", GlobalContextBuilder, LocalContextBuilder, DRIFTContextBuilder)


class BaseSearch(ABC, Generic[T]):
"""The Base Search implementation."""

def __init__(
self,
llm: BaseLLM,
context_builder: GlobalContextBuilder | LocalContextBuilder,
context_builder: T,
token_encoder: tiktoken.Encoding | None = None,
llm_params: dict[str, Any] | None = None,
context_builder_params: dict[str, Any] | None = None,
@@ -74,5 +78,5 @@ def astream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[str, None] | None:
"""Stream search for the given query."""
4 changes: 4 additions & 0 deletions graphrag/query/structured_search/drift_search/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""DriftSearch module."""