-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Collapse create_final_nodes * Update smoke tests * Typo --------- Co-authored-by: Alonso Guevara <[email protected]>
- Loading branch information
1 parent
fb65989
commit f8ab1b3
Showing
8 changed files
with
150 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"type": "patch", | ||
"description": "Collapse create-final-nodes." | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
graphrag/index/workflows/v1/subflows/create_final_nodes.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""All the steps to transform final nodes.""" | ||
|
||
from typing import Any, cast | ||
|
||
import pandas as pd | ||
from datashaper import ( | ||
Table, | ||
VerbCallbacks, | ||
VerbInput, | ||
verb, | ||
) | ||
from datashaper.table_store.types import VerbResult, create_verb_result | ||
|
||
from graphrag.index.verbs.graph.layout.layout_graph import layout_graph_df | ||
from graphrag.index.verbs.graph.unpack import unpack_graph_df | ||
|
||
|
||
@verb(name="create_final_nodes", treats_input_tables_as_immutable=True) | ||
def create_final_nodes( | ||
input: VerbInput, | ||
callbacks: VerbCallbacks, | ||
strategy: dict[str, Any], | ||
level_for_node_positions: int, | ||
**_kwargs: dict, | ||
) -> VerbResult: | ||
"""All the steps to transform final nodes.""" | ||
table = cast(pd.DataFrame, input.get_input()) | ||
|
||
laid_out_entity_graph = cast( | ||
pd.DataFrame, | ||
layout_graph_df( | ||
table, | ||
callbacks, | ||
strategy, | ||
embeddings_column="embeddings", | ||
graph_column="clustered_graph", | ||
to="node_positions", | ||
graph_to="positioned_graph", | ||
), | ||
) | ||
|
||
nodes = cast( | ||
pd.DataFrame, | ||
unpack_graph_df( | ||
laid_out_entity_graph, callbacks, column="positioned_graph", type="nodes" | ||
), | ||
) | ||
|
||
nodes_without_positions = nodes.drop(columns=["x", "y"]) | ||
|
||
nodes = nodes[nodes["level"] == level_for_node_positions].reset_index(drop=True) | ||
nodes = cast(pd.DataFrame, nodes[["id", "x", "y"]]) | ||
|
||
# TODO: original workflow saved an optional snapshot of top level nodes | ||
# Combining the verbs loses the `storage` injection, so it would fail | ||
# verb arg: snapshot_top_level_nodes: bool, | ||
# (name: "top_level_nodes", formats: ["json"]) | ||
|
||
nodes.rename(columns={"id": "top_level_node_id"}, inplace=True) | ||
nodes["top_level_node_id"] = nodes["top_level_node_id"].astype(str) | ||
|
||
joined = nodes_without_positions.merge( | ||
nodes, | ||
left_on="id", | ||
right_on="top_level_node_id", | ||
how="inner", | ||
) | ||
joined.rename(columns={"label": "title", "cluster": "community"}, inplace=True) | ||
|
||
return create_verb_result( | ||
cast( | ||
Table, | ||
joined, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
from graphrag.index.workflows.v1.create_final_nodes import ( | ||
build_steps, | ||
workflow_name, | ||
) | ||
|
||
from .util import ( | ||
compare_outputs, | ||
get_config_for_workflow, | ||
get_workflow_output, | ||
load_expected, | ||
load_input_tables, | ||
remove_disabled_steps, | ||
) | ||
|
||
|
||
async def test_create_final_nodes(): | ||
input_tables = load_input_tables([ | ||
"workflow:create_base_entity_graph", | ||
]) | ||
expected = load_expected(workflow_name) | ||
|
||
config = get_config_for_workflow(workflow_name) | ||
|
||
# default config turns UMAP off, which translates into false for layout | ||
# we don't have graph embeddings in the test data, so this will fail if True | ||
config["layout_graph_enabled"] = False | ||
|
||
steps = remove_disabled_steps(build_steps(config)) | ||
|
||
actual = await get_workflow_output( | ||
input_tables, | ||
{ | ||
"steps": steps, | ||
}, | ||
) | ||
|
||
compare_outputs(actual, expected) |