Skip to content

Commit

Permalink
Collapse create_final_nodes (#1171)
Browse files Browse the repository at this point in the history
* Collapse create_final_nodes

* Update smoke tests

* Typo

---------

Co-authored-by: Alonso Guevara <[email protected]>
  • Loading branch information
natoverse and AlonsoGuevara authored Sep 20, 2024
1 parent fb65989 commit f8ab1b3
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 82 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240920000120463201.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create-final-nodes."
}
21 changes: 19 additions & 2 deletions graphrag/index/verbs/graph/layout/layout_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,25 @@ def layout_graph(
min_dist: 0.75 # Optional, The min distance to use for the umap algorithm, default: 0.75
```
"""
output_df = cast(pd.DataFrame, input.get_input())
input_df = cast(pd.DataFrame, input.get_input())
output_df = layout_graph_df(
input_df, callbacks, strategy, embeddings_column, graph_column, to, graph_to
)

return TableContainer(table=output_df)


def layout_graph_df(
input_df: pd.DataFrame,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
embeddings_column: str,
graph_column: str,
to: str,
graph_to: str | None = None,
):
"""Apply a layout algorithm to a graph."""
output_df = input_df
num_items = len(output_df)
strategy_type = strategy.get("type", LayoutGraphStrategyType.umap)
strategy_args = {**strategy}
Expand Down Expand Up @@ -93,7 +110,7 @@ def layout_graph(
),
axis=1,
)
return TableContainer(table=output_df)
return output_df


def _run_layout(
Expand Down
83 changes: 5 additions & 78 deletions graphrag/index/workflows/v1/create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,55 +19,6 @@ def build_steps(
"""
snapshot_top_level_nodes = config.get("snapshot_top_level_nodes", False)
layout_graph_enabled = config.get("layout_graph_enabled", True)
_compute_top_level_node_positions = [
{
"verb": "unpack_graph",
"args": {"column": "positioned_graph", "type": "nodes"},
"input": {"source": "laid_out_entity_graph"},
},
{
"verb": "filter",
"args": {
"column": "level",
"criteria": [
{
"type": "value",
"operator": "equals",
"value": config.get("level_for_node_positions", 0),
}
],
},
},
{
"verb": "select",
"args": {"columns": ["id", "x", "y"]},
},
{
"verb": "snapshot",
"enabled": snapshot_top_level_nodes,
"args": {
"name": "top_level_nodes",
"formats": ["json"],
},
},
{
"id": "_compute_top_level_node_positions",
"verb": "rename",
"args": {
"columns": {
"id": "top_level_node_id",
}
},
},
{
"verb": "convert",
"args": {
"column": "top_level_node_id",
"to": "top_level_node_id",
"type": "string",
},
},
]
layout_graph_config = config.get(
"layout_graph",
{
Expand All @@ -76,41 +27,17 @@ def build_steps(
},
},
)
level_for_node_positions = config.get("level_for_node_positions", 0)

return [
{
"id": "laid_out_entity_graph",
"verb": "layout_graph",
"verb": "create_final_nodes",
"args": {
"embeddings_column": "embeddings",
"graph_column": "clustered_graph",
"to": "node_positions",
"graph_to": "positioned_graph",
**layout_graph_config,
"level_for_node_positions": level_for_node_positions,
"snapshot_top_level_nodes": snapshot_top_level_nodes,
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"verb": "unpack_graph",
"args": {"column": "positioned_graph", "type": "nodes"},
},
{
"id": "nodes_without_positions",
"verb": "drop",
"args": {"columns": ["x", "y"]},
},
*_compute_top_level_node_positions,
{
"verb": "join",
"args": {
"on": ["id", "top_level_node_id"],
},
"input": {
"source": "nodes_without_positions",
"others": ["_compute_top_level_node_positions"],
},
},
{
"verb": "rename",
"args": {"columns": {"label": "title", "cluster": "community"}},
},
]
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""The Indexing Engine workflows -> subflows package root."""

from .create_final_communities import create_final_communities
from .create_final_nodes import create_final_nodes
from .create_final_relationships_post_embedding import (
create_final_relationships_post_embedding,
)
Expand All @@ -14,6 +15,7 @@

__all__ = [
"create_final_communities",
"create_final_nodes",
"create_final_relationships_post_embedding",
"create_final_relationships_pre_embedding",
"create_final_text_units_pre_embedding",
Expand Down
78 changes: 78 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_final_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform final nodes."""

from typing import Any, cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.graph.layout.layout_graph import layout_graph_df
from graphrag.index.verbs.graph.unpack import unpack_graph_df


@verb(name="create_final_nodes", treats_input_tables_as_immutable=True)
def create_final_nodes(
input: VerbInput,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
level_for_node_positions: int,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final nodes."""
table = cast(pd.DataFrame, input.get_input())

laid_out_entity_graph = cast(
pd.DataFrame,
layout_graph_df(
table,
callbacks,
strategy,
embeddings_column="embeddings",
graph_column="clustered_graph",
to="node_positions",
graph_to="positioned_graph",
),
)

nodes = cast(
pd.DataFrame,
unpack_graph_df(
laid_out_entity_graph, callbacks, column="positioned_graph", type="nodes"
),
)

nodes_without_positions = nodes.drop(columns=["x", "y"])

nodes = nodes[nodes["level"] == level_for_node_positions].reset_index(drop=True)
nodes = cast(pd.DataFrame, nodes[["id", "x", "y"]])

# TODO: original workflow saved an optional snapshot of top level nodes
# Combining the verbs loses the `storage` injection, so it would fail
# verb arg: snapshot_top_level_nodes: bool,
# (name: "top_level_nodes", formats: ["json"])

nodes.rename(columns={"id": "top_level_node_id"}, inplace=True)
nodes["top_level_node_id"] = nodes["top_level_node_id"].astype(str)

joined = nodes_without_positions.merge(
nodes,
left_on="id",
right_on="top_level_node_id",
how="inner",
)
joined.rename(columns={"label": "title", "cluster": "community"}, inplace=True)

return create_verb_result(
cast(
Table,
joined,
)
)
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"community",
"level"
],
"subworkflows": 10,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_communities": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
"community",
"level"
],
"subworkflows": 10,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_communities": {
Expand Down
40 changes: 40 additions & 0 deletions tests/verbs/test_create_final_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_final_nodes import (
build_steps,
workflow_name,
)

from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)


async def test_create_final_nodes():
input_tables = load_input_tables([
"workflow:create_base_entity_graph",
])
expected = load_expected(workflow_name)

config = get_config_for_workflow(workflow_name)

# default config turns UMAP off, which translates into false for layout
# we don't have graph embeddings in the test data, so this will fail if True
config["layout_graph_enabled"] = False

steps = remove_disabled_steps(build_steps(config))

actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)

compare_outputs(actual, expected)

0 comments on commit f8ab1b3

Please sign in to comment.