Basic search implementation (#1563)
* basic search implementation

* basic streaming functionality

* format check

* check fix

* release change

* Chore/gleanings any encoding (#1569)

* Make claims and entities independent of encoding

* Semver

* Change semver release type

---------

Co-authored-by: Alonso Guevara <[email protected]>
gaudyb and AlonsoGuevara authored Jan 2, 2025
1 parent 5f9ad0d commit 185f513
Showing 22 changed files with 915 additions and 198 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20241227205339264730.json
@@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Basic search implemented as a new option for the query API."
}
4 changes: 4 additions & 0 deletions graphrag/api/__init__.py
@@ -10,6 +10,8 @@
from graphrag.api.index import build_index
from graphrag.api.prompt_tune import generate_indexing_prompts
from graphrag.api.query import (
    basic_search,
    basic_search_streaming,
    drift_search,
    global_search,
    global_search_streaming,
@@ -27,6 +29,8 @@
"local_search",
"local_search_streaming",
"drift_search",
"basic_search",
"basic_search_streaming",
# prompt tuning API
"DocSelectionType",
"generate_indexing_prompts",
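With both names re-exported here, callers can reach the new entry points from the package root (import graphrag.api as api, then api.basic_search / api.basic_search_streaming); a fuller usage sketch follows the query.py changes below.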
105 changes: 105 additions & 0 deletions graphrag/api/query.py
@@ -28,9 +28,11 @@
from graphrag.index.config.embeddings import (
    community_full_content_embedding,
    entity_description_embedding,
    text_unit_text_embedding,
)
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
    get_basic_search_engine,
    get_drift_search_engine,
    get_global_search_engine,
    get_local_search_engine,
@@ -423,6 +425,109 @@ async def drift_search(
    return response, context_data


@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search(
    config: GraphRagConfig,
    text_units: pd.DataFrame,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a basic search and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - query (str): The user query to search for.

    Returns
    -------
    A tuple of (response, context_data): the generated answer and the context records used to produce it.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = config.embeddings.vector_store
    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

    description_embedding_store = _get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=text_unit_text_embedding,
    )

    prompt = _load_search_prompt(config.root_dir, config.basic_search.prompt)

    search_engine = get_basic_search_engine(
        config=config,
        text_units=read_indexer_text_units(text_units),
        text_unit_embeddings=description_embedding_store,
        system_prompt=prompt,
    )

    result: SearchResult = await search_engine.asearch(query=query)
    response = result.response
    context_data = _reformat_context_data(result.context_data)  # type: ignore
    return response, context_data


@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search_streaming(
    config: GraphRagConfig,
    text_units: pd.DataFrame,
    query: str,
) -> AsyncGenerator:
    """Perform a basic search and return the context data and response via a generator.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - query (str): The user query to search for.

    Returns
    -------
    An async generator that yields the context data first, then the response tokens as they are produced.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = config.embeddings.vector_store
    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

    description_embedding_store = _get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=text_unit_text_embedding,
    )

    prompt = _load_search_prompt(config.root_dir, config.basic_search.prompt)

    search_engine = get_basic_search_engine(
        config=config,
        text_units=read_indexer_text_units(text_units),
        text_unit_embeddings=description_embedding_store,
        system_prompt=prompt,
    )

    search_result = search_engine.astream_search(query=query)

    # when streaming results, a context data object is returned as the first result
    # and the query response comes in subsequent tokens
    context_data = None
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = _reformat_context_data(stream_chunk)  # type: ignore
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk


def _get_embedding_store(
    config_args: dict,
    embedding_name: str,
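Taken together, the two entry points mirror the existing local/global/drift APIs: load the text-unit parquet and a config, then either await a one-shot answer or iterate a token stream. A minimal sketch of driving both from user code; the load_config import path and the parquet location are assumptions (the call signature matches graphrag/cli/query.py later in this diff):

import asyncio
from pathlib import Path

import pandas as pd

import graphrag.api as api
from graphrag.config.load_config import load_config  # import path is an assumption


async def main() -> None:
    root = Path("./project").resolve()
    # load_config(root_dir, config_filepath) mirrors the call in graphrag/cli/query.py below
    config = load_config(root, None)
    # parquet location is an assumption; match it to your configured storage.base_dir
    text_units = pd.read_parquet(root / "output" / "create_final_text_units.parquet")

    # one-shot search: returns (response, context_data)
    response, context = await api.basic_search(
        config=config,
        text_units=text_units,
        query="What themes recur across the documents?",
    )
    print(response)

    # streaming search: the first chunk is context data, later chunks are response tokens
    first = True
    async for chunk in api.basic_search_streaming(
        config=config,
        text_units=text_units,
        query="Summarize the corpus.",
    ):
        if first:
            context = chunk  # context data arrives before any tokens
            first = False
        else:
            print(chunk, end="", flush=True)


asyncio.run(main())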
2 changes: 2 additions & 0 deletions graphrag/cli/initialize.py
@@ -13,6 +13,7 @@
)
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
    GENERAL_KNOWLEDGE_INSTRUCTION,
@@ -60,6 +61,7 @@ def initialize_project_at(path: Path) -> None:
"global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
"global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
"local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT,
"basic_search_system_prompt": BASIC_SEARCH_SYSTEM_PROMPT,
"question_gen_system_prompt": QUESTION_SYSTEM_PROMPT,
}

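Net effect: graphrag init now writes the basic-search prompt template out with the rest, at prompts/basic_search_system_prompt.txt in the project root — the same path the generated settings stanza (init_content.py below) references.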
16 changes: 15 additions & 1 deletion graphrag/cli/main.py
@@ -88,6 +88,7 @@ class SearchType(Enum):
LOCAL = "local"
GLOBAL = "global"
DRIFT = "drift"
BASIC = "basic"

def __str__(self):
"""Return the string representation of the enum value."""
@@ -424,7 +425,12 @@ def _query_cli(
    ] = False,
):
    """Query a knowledge graph index."""
    from graphrag.cli.query import run_drift_search, run_global_search, run_local_search
    from graphrag.cli.query import (
        run_basic_search,
        run_drift_search,
        run_global_search,
        run_local_search,
    )

    match method:
        case SearchType.LOCAL:
@@ -457,5 +463,13 @@
                streaming=False,  # Drift search does not support streaming (yet)
                query=query,
            )
        case SearchType.BASIC:
            run_basic_search(
                config_filepath=config,
                data_dir=data,
                root_dir=root,
                streaming=streaming,
                query=query,
            )
        case _:
            raise ValueError(INVALID_METHOD_ERROR)
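With the enum member wired into the match statement, the new mode is selectable from the CLI. Assuming the option names follow the parameter names above (not taken from CLI docs), an invocation would look something like:

graphrag query --method basic --root ./project --query "What are the top themes?"

Add --streaming to exercise the token-stream path.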
63 changes: 63 additions & 0 deletions graphrag/cli/query.py
@@ -257,6 +257,69 @@ def run_drift_search(
    return response, context_data


def run_basic_search(
    config_filepath: Path | None,
    data_dir: Path | None,
    root_dir: Path,
    streaming: bool,
    query: str,
):
    """Perform a basic search with a given query.

    Loads index files required for basic search and calls the Query API.
    """
    root = root_dir.resolve()
    config = load_config(root, config_filepath)
    config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_output_files(
        config=config,
        output_list=[
            "create_final_text_units.parquet",
        ],
    )
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]

    # call the Query API
    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            async for stream_chunk in api.basic_search_streaming(
                config=config,
                text_units=final_text_units,
                query=query,
            ):
                if get_context_data:
                    context_data = stream_chunk
                    get_context_data = False
                else:
                    full_response += stream_chunk
                    print(stream_chunk, end="")  # noqa: T201
                    sys.stdout.flush()  # flush output buffer to display text immediately
            print()  # noqa: T201
            return full_response, context_data

        return asyncio.run(run_streaming_search())

    # not streaming
    response, context_data = asyncio.run(
        api.basic_search(
            config=config,
            text_units=final_text_units,
            query=query,
        )
    )
    logger.success(f"Basic Search Response:\n{response}")
    # NOTE: we return the response and context data here purely as a complete demonstration of the API.
    # External users should use the API directly to get the response and context data.
    return response, context_data


def _resolve_output_files(
    config: GraphRagConfig,
    output_list: list[str],
25 changes: 25 additions & 0 deletions graphrag/config/create_graphrag_config.py
@@ -30,6 +30,7 @@
)
from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput
from graphrag.config.input_models.llm_config_input import LLMConfigInput
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
@@ -636,6 +637,28 @@ def hydrate_parallelization_params(
            or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
        )

    with (
        reader.use(values.get("basic_search")),
        reader.envvar_prefix(Section.basic_search),
    ):
        basic_search_model = BasicSearchConfig(
            prompt=reader.str("prompt") or None,
            text_unit_prop=reader.float("text_unit_prop")
            or defs.BASIC_SEARCH_TEXT_UNIT_PROP,
            conversation_history_max_turns=reader.int(
                "conversation_history_max_turns"
            )
            or defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
            temperature=reader.float("llm_temperature")
            or defs.BASIC_SEARCH_LLM_TEMPERATURE,
            top_p=reader.float("llm_top_p") or defs.BASIC_SEARCH_LLM_TOP_P,
            n=reader.int("llm_n") or defs.BASIC_SEARCH_LLM_N,
            max_tokens=reader.int(Fragment.max_tokens)
            or defs.BASIC_SEARCH_MAX_TOKENS,
            llm_max_tokens=reader.int("llm_max_tokens")
            or defs.BASIC_SEARCH_LLM_MAX_TOKENS,
        )

    skip_workflows = reader.list("skip_workflows") or []

    return GraphRagConfig(
@@ -663,6 +686,7 @@
        local_search=local_search_model,
        global_search=global_search_model,
        drift_search=drift_search_model,
        basic_search=basic_search_model,
    )


@@ -731,6 +755,7 @@ class Section(str, Enum):
local_search = "LOCAL_SEARCH"
global_search = "GLOBAL_SEARCH"
drift_search = "DRIFT_SEARCH"
basic_search = "BASIC_SEARCH"


def _is_azure(llm_type: LLMType | None) -> bool:
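Because the reader is scoped to an envvar prefix of BASIC_SEARCH, every field read above should also be settable from the environment. Assuming the GRAPHRAG_ prefix the config reader uses for its environment lookups, overrides would look like:

GRAPHRAG_BASIC_SEARCH_TEXT_UNIT_PROP=0.5
GRAPHRAG_BASIC_SEARCH_LLM_TEMPERATURE=0
GRAPHRAG_BASIC_SEARCH_MAX_TOKENS=12000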
9 changes: 9 additions & 0 deletions graphrag/config/defaults.py
@@ -161,3 +161,12 @@
DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS = 2000

DRIFT_N_DEPTH = 3

# Basic Search
BASIC_SEARCH_TEXT_UNIT_PROP = 0.5
BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS = 5
BASIC_SEARCH_MAX_TOKENS = 12_000
BASIC_SEARCH_LLM_TEMPERATURE = 0
BASIC_SEARCH_LLM_TOP_P = 1
BASIC_SEARCH_LLM_N = 1
BASIC_SEARCH_LLM_MAX_TOKENS = 2000
3 changes: 3 additions & 0 deletions graphrag/config/init_content.py
@@ -132,6 +132,9 @@
drift_search:
  prompt: "prompts/drift_search_system_prompt.txt"

basic_search:
  prompt: "prompts/basic_search_system_prompt.txt"
"""

INIT_DOTENV = """\
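The generated stanza pins only the prompt path, but the loader in create_graphrag_config.py reads several more keys from the same section. A fuller settings.yaml override, sketched with the defaults from defaults.py (key names taken from the reader calls above, not from a documented schema):

basic_search:
  prompt: "prompts/basic_search_system_prompt.txt"
  text_unit_prop: 0.5
  conversation_history_max_turns: 5
  llm_temperature: 0
  llm_top_p: 1
  llm_n: 1
  max_tokens: 12000
  llm_max_tokens: 2000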
15 changes: 15 additions & 0 deletions graphrag/config/input_models/basic_search_config_input.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

from typing_extensions import NotRequired, TypedDict


class BasicSearchConfigInput(TypedDict):
    """The default configuration section for Basic Search."""

    text_unit_prop: NotRequired[float | str | None]
    conversation_history_max_turns: NotRequired[int | str | None]
    max_tokens: NotRequired[int | str | None]
    llm_max_tokens: NotRequired[int | str | None]
4 changes: 4 additions & 0 deletions graphrag/config/input_models/graphrag_config_input.py
@@ -5,6 +5,9 @@

from typing_extensions import NotRequired

from graphrag.config.input_models.basic_search_config_input import (
    BasicSearchConfigInput,
)
from graphrag.config.input_models.cache_config_input import CacheConfigInput
from graphrag.config.input_models.chunking_config_input import ChunkingConfigInput
from graphrag.config.input_models.claim_extraction_config_input import (
@@ -61,3 +64,4 @@ class GraphRagConfigInput(LLMConfigInput):
    skip_workflows: NotRequired[list[str] | str | None]
    local_search: NotRequired[LocalSearchConfigInput | None]
    global_search: NotRequired[GlobalSearchConfigInput | None]
    basic_search: NotRequired[BasicSearchConfigInput | None]
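Since the input models are TypedDicts, a partial override is an ordinary dictionary; a small sketch with hypothetical values:

from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput

overrides: GraphRagConfigInput = {
    "basic_search": {
        "text_unit_prop": 0.7,  # hypothetical value
        "max_tokens": 8_000,  # hypothetical value
    },
}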