Skip to content

Commit

Permalink
Collapse create base extracted entities (#1235)
Browse files Browse the repository at this point in the history
* Set up base assertions

* Replace entity_extract

* Finish collapsing workflow

* Semver

* Update smoke tests
  • Loading branch information
natoverse authored Oct 1, 2024
1 parent 630679f commit 9070ea5
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 79 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240930230641593846.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse entity extraction."
}
41 changes: 36 additions & 5 deletions graphrag/index/verbs/entities/extraction/entity_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,38 @@ async def entity_extract(
entity_types=DEFAULT_ENTITY_TYPES,
**kwargs,
) -> TableContainer:
"""Extract entities from a piece of text."""
source = cast(pd.DataFrame, input.get_input())
output = await entity_extract_df(
source,
cache,
callbacks,
column,
id_column,
to,
strategy,
graph_to,
async_mode,
entity_types,
**kwargs,
)

return TableContainer(table=output)


async def entity_extract_df(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
id_column: str,
to: str,
strategy: dict[str, Any] | None,
graph_to: str | None = None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types=DEFAULT_ENTITY_TYPES,
**kwargs,
) -> pd.DataFrame:
"""
Extract entities from a piece of text.
Expand Down Expand Up @@ -135,7 +167,6 @@ async def entity_extract(
log.debug("entity_extract strategy=%s", strategy)
if entity_types is None:
entity_types = DEFAULT_ENTITY_TYPES
output = cast(pd.DataFrame, input.get_input())
strategy = strategy or {}
strategy_exec = _load_strategy(
strategy.get("type", ExtractEntityStrategyType.graph_intelligence)
Expand All @@ -159,7 +190,7 @@ async def run_strategy(row):
return [result.entities, result.graphml_graph]

results = await derive_from_rows(
output,
input,
run_strategy,
callbacks,
scheduling_type=async_mode,
Expand All @@ -176,11 +207,11 @@ async def run_strategy(row):
to_result.append(None)
graph_to_result.append(None)

output[to] = to_result
input[to] = to_result
if graph_to is not None:
output[graph_to] = graph_to_result
input[graph_to] = graph_to_result

return TableContainer(table=output.reset_index(drop=True))
return input.reset_index(drop=True)


def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
Expand Down
22 changes: 18 additions & 4 deletions graphrag/index/verbs/graph/merge/merge_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ def merge_graphs(
edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
**_kwargs,
) -> TableContainer:
"""Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph."""
input_df = cast(pd.DataFrame, input.get_input())
output = merge_graphs_df(input_df, callbacks, column, to, nodes, edges)

return TableContainer(table=output)


def merge_graphs_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
column: str,
to: str,
nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
) -> pd.DataFrame:
"""
Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
Expand Down Expand Up @@ -82,7 +97,6 @@ def merge_graphs(
- __average__: This operation takes the mean of the attribute with the last value seen.
- __multiply__: This operation multiplies the attribute with the last value seen.
"""
input_df = input.get_input()
output = pd.DataFrame()

node_ops = {
Expand All @@ -95,15 +109,15 @@ def merge_graphs(
}

mega_graph = nx.Graph()
num_total = len(input_df)
for graphml in progress_iterable(input_df[column], callbacks.progress, num_total):
num_total = len(input)
for graphml in progress_iterable(input[column], callbacks.progress, num_total):
graph = load_graph(cast(str | nx.Graph, graphml))
merge_nodes(mega_graph, graph, node_ops)
merge_edges(mega_graph, graph, edge_ops)

output[to] = ["\n".join(nx.generate_graphml(mega_graph))]

return TableContainer(table=output)
return output


def merge_nodes(
Expand Down
23 changes: 18 additions & 5 deletions graphrag/index/verbs/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

"""A module containing snapshot method definition."""

from typing import cast

import pandas as pd
from datashaper import TableContainer, VerbInput, verb

from graphrag.index.storage import PipelineStorage
Expand All @@ -17,14 +20,24 @@ async def snapshot(
**_kwargs: dict,
) -> TableContainer:
"""Take a entire snapshot of the tabular data."""
data = input.get_input()
data = cast(pd.DataFrame, input.get_input())

await snapshot_df(data, name, formats, storage)

return TableContainer(table=data)


async def snapshot_df(
    input: pd.DataFrame,
    name: str,
    formats: list[str],
    storage: PipelineStorage,
) -> None:
    """Take an entire snapshot of the tabular data.

    Writes *input* to *storage* once per requested format:
    - "parquet": stored as `<name>.parquet`
    - "json": stored as `<name>.json`, one record per line (JSON Lines)

    Unknown format strings are silently ignored.
    """
    for fmt in formats:
        if fmt == "parquet":
            await storage.set(name + ".parquet", input.to_parquet())
        elif fmt == "json":
            await storage.set(
                name + ".json", input.to_json(orient="records", lines=True)
            )
115 changes: 52 additions & 63 deletions graphrag/index/workflows/v1/create_base_extracted_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,76 +20,65 @@ def build_steps(
* `workflow:create_base_text_units`
"""
entity_extraction_config = config.get("entity_extract", {})
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False

return [
column = entity_extraction_config.get("text_column", "chunk")
id_column = entity_extraction_config.get("id_column", "chunk_id")
async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
strategy = entity_extraction_config.get("strategy")
num_threads = entity_extraction_config.get("num_threads", 4)
entity_types = entity_extraction_config.get("entity_types")

graph_merge_operations_config = config.get(
"graph_merge_operations",
{
"verb": "entity_extract",
"args": {
**entity_extraction_config,
"column": entity_extraction_config.get("text_column", "chunk"),
"id_column": entity_extraction_config.get("id_column", "chunk_id"),
"async_mode": entity_extraction_config.get(
"async_mode", AsyncType.AsyncIO
),
"to": "entities",
"graph_to": "entity_graph",
"nodes": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
},
"input": {"source": "workflow:create_base_text_units"},
},
{
"verb": "snapshot",
"enabled": raw_entity_snapshot_enabled,
"args": {
"name": "raw_extracted_entities",
"formats": ["json"],
},
},
{
"verb": "merge_graphs",
"args": {
"column": "entity_graph",
"to": "entity_graph",
**config.get(
"graph_merge_operations",
{
"nodes": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
},
"edges": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
"weight": "sum",
},
},
),
"edges": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
"weight": "sum",
},
},
)
nodes = graph_merge_operations_config.get("nodes")
edges = graph_merge_operations_config.get("edges")

graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False

return [
{
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"verb": "create_base_extracted_entities",
"args": {
"base_name": "merged_graph",
"column": "entity_graph",
"formats": [{"format": "text", "extension": "graphml"}],
"column": column,
"id_column": id_column,
"async_mode": async_mode,
"strategy": strategy,
"num_threads": num_threads,
"entity_types": entity_types,
"nodes": nodes,
"edges": edges,
"raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
},
"input": {"source": "workflow:create_base_text_units"},
},
]
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .create_base_documents import create_base_documents
from .create_base_entity_graph import create_base_entity_graph
from .create_base_extracted_entities import create_base_extracted_entities
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_community_reports import create_final_community_reports
Expand All @@ -21,6 +22,7 @@
__all__ = [
"create_base_documents",
"create_base_entity_graph",
"create_base_extracted_entities",
"create_base_text_units",
"create_final_communities",
"create_final_community_reports",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to extract and format covariates."""

from typing import Any, cast

import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df
from graphrag.index.verbs.snapshot import snapshot_df
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df


@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
async def create_base_extracted_entities(
    input: VerbInput,
    cache: PipelineCache,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    column: str,
    id_column: str,
    nodes: dict[str, Any],
    edges: dict[str, Any],
    strategy: dict[str, Any] | None,
    async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
    **kwargs: dict,
) -> VerbResult:
    """All the steps to extract base entities and merge them into a single graph.

    Runs entity extraction over the source table, optionally snapshots the raw
    results, merges the per-row graphs into one graph, and optionally snapshots
    the merged graph as graphml.

    Args:
        input: Verb input; its table is the source text units.
        cache: Pipeline cache passed through to the extraction strategy.
        callbacks: Verb callbacks for progress reporting.
        storage: Pipeline storage used for the optional snapshots.
        column: Name of the text column to extract entities from.
        id_column: Name of the id column identifying each text unit.
        nodes: Node-attribute merge operations for the graph merge.
        edges: Edge-attribute merge operations for the graph merge.
        strategy: Entity-extraction strategy config (None uses the default).
        async_mode: Scheduling mode for the per-row extraction.
        entity_types: Entity types to extract (None uses the default set).
        graphml_snapshot_enabled: If True, snapshot the merged graph as graphml.
        raw_entity_snapshot_enabled: If True, snapshot raw entities as JSON.
        **kwargs: Extra arguments forwarded to the extraction step.

    Returns:
        A VerbResult wrapping the merged-graph table.
    """
    source = cast(pd.DataFrame, input.get_input())

    # Per-row extraction: adds "entities" and a per-row "entity_graph" column.
    entity_graph = await entity_extract_df(
        source,
        cache,
        callbacks,
        column=column,
        id_column=id_column,
        strategy=strategy,
        async_mode=async_mode,
        entity_types=entity_types,
        to="entities",
        graph_to="entity_graph",
        **kwargs,
    )

    if raw_entity_snapshot_enabled:
        await snapshot_df(
            entity_graph,
            name="raw_extracted_entities",
            storage=storage,
            formats=["json"],
        )

    # Collapse all per-row graphs into a single merged graph.
    merged_graph = merge_graphs_df(
        entity_graph,
        callbacks,
        column="entity_graph",
        to="entity_graph",
        nodes=nodes,
        edges=edges,
    )

    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            merged_graph,
            base_name="merged_graph",
            column="entity_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    return create_verb_result(cast(Table, merged_graph))
Loading

0 comments on commit 9070ea5

Please sign in to comment.