From be7d3eb189e56f96490a00b34a9101d19379f037 Mon Sep 17 00:00:00 2001
From: Alonso Guevara <alonsog@microsoft.com>
Date: Mon, 23 Sep 2024 16:54:15 -0600
Subject: [PATCH] Remove aggregate_df from final communities and final text
 units (#1179)

* Remove aggregate_df from final communities and final text units

* Semver

* Ruff and format

* Format

* Format

* Fix tests, ruff and checks

* Remove some leftover prints

* Removed _final_join method
---
 .../patch-20240920221112632172.json           |   4 +
 .../v1/subflows/create_final_communities.py   |  63 +++-----
 .../create_final_text_units_pre_embedding.py  | 144 +++++-------------
 3 files changed, 63 insertions(+), 148 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20240920221112632172.json

diff --git a/.semversioner/next-release/patch-20240920221112632172.json b/.semversioner/next-release/patch-20240920221112632172.json
new file mode 100644
index 0000000000..47a2a6d76f
--- /dev/null
+++ b/.semversioner/next-release/patch-20240920221112632172.json
@@ -0,0 +1,4 @@
+{
+  "type": "patch",
+  "description": "Remove aggregate_df from final communities and final text units"
+}
diff --git a/graphrag/index/workflows/v1/subflows/create_final_communities.py b/graphrag/index/workflows/v1/subflows/create_final_communities.py
index 5db80fc6af..2cbe8f6cae 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_communities.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_communities.py
@@ -15,7 +15,6 @@
 from datashaper.table_store.types import VerbResult, create_verb_result
 
 from graphrag.index.verbs.graph.unpack import unpack_graph_df
-from graphrag.index.verbs.overrides.aggregate import aggregate_df
 
 
 @verb(name="create_final_communities", treats_input_tables_as_immutable=True)
@@ -30,54 +29,35 @@ def create_final_communities(
     graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
     graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
 
+    # Merge graph_nodes with graph_edges for both source and target matches
     source_clusters = graph_nodes.merge(
-        graph_edges,
-        left_on="label",
-        right_on="source",
-        how="inner",
+        graph_edges, left_on="label", right_on="source", how="inner"
     )
+
     target_clusters = graph_nodes.merge(
-        graph_edges,
-        left_on="label",
-        right_on="target",
-        how="inner",
+        graph_edges, left_on="label", right_on="target", how="inner"
     )
 
-    concatenated_clusters = pd.concat(
-        [source_clusters, target_clusters], ignore_index=True
-    )
+    # Concatenate the source and target clusters
+    clusters = pd.concat([source_clusters, target_clusters], ignore_index=True)
 
-    # level_x is the left side of the join
-    # level_y is the right side of the join
-    # we only want to keep the clusters that are the same on both sides
-    combined_clusters = concatenated_clusters[
-        concatenated_clusters["level_x"] == concatenated_clusters["level_y"]
+    # Keep only rows where level_x == level_y
+    combined_clusters = clusters[
+        clusters["level_x"] == clusters["level_y"]
     ].reset_index(drop=True)
 
-    cluster_relationships = aggregate_df(
-        cast(Table, combined_clusters),
-        aggregations=[
-            {
-                "column": "id_y",  # this is the id of the edge from the join steps above
-                "to": "relationship_ids",
-                "operation": "array_agg_distinct",
-            },
-            {
-                "column": "source_id_x",
-                "to": "text_unit_ids",
-                "operation": "array_agg_distinct",
-            },
-        ],
-        groupby=[
-            "cluster",
-            "level_x",  # level_x is the left side of the join
-        ],
+    cluster_relationships = (
+        combined_clusters.groupby(["cluster", "level_x"], sort=False)
+        .agg(
+            relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique")
+        )
+        .reset_index()
     )
 
-    all_clusters = aggregate_df(
-        graph_nodes,
-        aggregations=[{"column": "cluster", "to": "id", "operation": "any"}],
-        groupby=["cluster", "level"],
+    all_clusters = (
+        graph_nodes.groupby(["cluster", "level"], sort=False)
+        .agg(id=("cluster", "first"))
+        .reset_index()
     )
 
     joined = all_clusters.merge(
@@ -94,14 +74,15 @@ def create_final_communities(
     return create_verb_result(
         cast(
             Table,
-            filtered[
+            filtered.loc[
+                :,
                 [
                     "id",
                     "title",
                     "level",
                     "relationship_ids",
                     "text_unit_ids",
-                ]
+                ],
             ],
         )
     )
diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py b/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py
index 5cb48f0513..49ebd81986 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py
@@ -5,12 +5,11 @@
 
 from typing import cast
 
+import pandas as pd
 from datashaper.engine.verbs.verb_input import VerbInput
 from datashaper.engine.verbs.verbs_mapping import verb
 from datashaper.table_store.types import Table, VerbResult, create_verb_result
 
-from graphrag.index.verbs.overrides.aggregate import aggregate_df
-
 
 @verb(
     name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
@@ -21,15 +20,15 @@ def create_final_text_units_pre_embedding(
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform before we embed the text units."""
-    table = input.get_input()
+    table = cast(pd.DataFrame, input.get_input())
     others = input.get_others()
 
-    selected = cast(Table, table[["id", "chunk", "document_ids", "n_tokens"]]).rename(
+    selected = table.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
         columns={"chunk": "text"}
     )
 
-    final_entities = others[0]
-    final_relationships = others[1]
+    final_entities = cast(pd.DataFrame, others[0])
+    final_relationships = cast(pd.DataFrame, others[1])
     entity_join = _entities(final_entities)
     relationship_join = _relationships(final_relationships)
 
@@ -38,116 +37,47 @@ def create_final_text_units_pre_embedding(
     final_joined = relationship_joined
 
     if covariates_enabled:
-        final_covariates = others[2]
+        final_covariates = cast(pd.DataFrame, others[2])
         covariate_join = _covariates(final_covariates)
         final_joined = _join(relationship_joined, covariate_join)
 
-    aggregated = _final_aggregation(final_joined, covariates_enabled)
-
-    return create_verb_result(aggregated)
-
-
-def _final_aggregation(table, covariates_enabled):
-    aggregations = [
-        {
-            "column": "text",
-            "operation": "any",
-            "to": "text",
-        },
-        {
-            "column": "n_tokens",
-            "operation": "any",
-            "to": "n_tokens",
-        },
-        {
-            "column": "document_ids",
-            "operation": "any",
-            "to": "document_ids",
-        },
-        {
-            "column": "entity_ids",
-            "operation": "any",
-            "to": "entity_ids",
-        },
-        {
-            "column": "relationship_ids",
-            "operation": "any",
-            "to": "relationship_ids",
-        },
-    ]
-    if covariates_enabled:
-        aggregations.append({
-            "column": "covariate_ids",
-            "operation": "any",
-            "to": "covariate_ids",
-        })
-    return aggregate_df(
-        table,
-        aggregations,
-        ["id"],
-    )
+    aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
+
+    return create_verb_result(cast(Table, aggregated))
+
 
+def _entities(df: pd.DataFrame) -> pd.DataFrame:
+    selected = df.loc[:, ["id", "text_unit_ids"]]
+    unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)
 
-def _entities(table):
-    selected = cast(Table, table[["id", "text_unit_ids"]])
-    unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
-    return aggregate_df(
-        unrolled,
-        [
-            {
-                "column": "id",
-                "operation": "array_agg_distinct",
-                "to": "entity_ids",
-            },
-            {
-                "column": "text_unit_ids",
-                "operation": "any",
-                "to": "id",
-            },
-        ],
-        ["text_unit_ids"],
+    return (
+        unrolled.groupby("text_unit_ids", sort=False)
+        .agg(entity_ids=("id", "unique"))
+        .reset_index()
+        .rename(columns={"text_unit_ids": "id"})
     )
 
 
-def _relationships(table):
-    selected = cast(Table, table[["id", "text_unit_ids"]])
-    unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
-    aggregated = aggregate_df(
-        unrolled,
-        [
-            {
-                "column": "id",
-                "operation": "array_agg_distinct",
-                "to": "relationship_ids",
-            },
-            {
-                "column": "text_unit_ids",
-                "operation": "any",
-                "to": "id",
-            },
-        ],
-        ["text_unit_ids"],
+def _relationships(df: pd.DataFrame) -> pd.DataFrame:
+    selected = df.loc[:, ["id", "text_unit_ids"]]
+    unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)
+
+    return (
+        unrolled.groupby("text_unit_ids", sort=False)
+        .agg(relationship_ids=("id", "unique"))
+        .reset_index()
+        .rename(columns={"text_unit_ids": "id"})
     )
-    return aggregated[["id", "relationship_ids"]]
-
-
-def _covariates(table):
-    selected = cast(Table, table[["id", "text_unit_id"]])
-    return aggregate_df(
-        selected,
-        [
-            {
-                "column": "id",
-                "operation": "array_agg_distinct",
-                "to": "covariate_ids",
-            },
-            {
-                "column": "text_unit_id",
-                "operation": "any",
-                "to": "id",
-            },
-        ],
-        ["text_unit_id"],
+
+
+def _covariates(df: pd.DataFrame) -> pd.DataFrame:
+    selected = df.loc[:, ["id", "text_unit_id"]]
+
+    return (
+        selected.groupby("text_unit_id", sort=False)
+        .agg(covariate_ids=("id", "unique"))
+        .reset_index()
+        .rename(columns={"text_unit_id": "id"})
     )