From fbc483e4e57c76e198a828b773af4896486df7ee Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Mon, 23 Sep 2024 13:24:06 -0700
Subject: [PATCH] Collapse create base documents (#1176)

* Collapse non-attribute verbs
* Include document_column_attributes in collapse
* Remove merge_override verb
* Semver
* Clean up some df/tests

---
 .semversioner/next-release/patch-20240920192804408249.json |  4 +
 graphrag/index/verbs/__init__.py                           |  3 +-
 graphrag/index/verbs/overrides/__init__.py                 |  3 +-
 graphrag/index/verbs/overrides/merge.py                    | 78 -----------------
 .../workflows/v1/create_base_documents.py                  | 79 +----------------
 .../index/workflows/v1/subflows/__init__.py                |  2 +
 .../v1/subflows/create_base_documents.py                   | 84 +++++++++++++++++++
 tests/fixtures/min-csv/config.json                         |  2 +-
 tests/fixtures/text/config.json                            |  2 +-
 tests/verbs/test_create_base_documents.py                  | 58 +++++++++++++
 tests/verbs/util.py                                        | 32 ++++---
 11 files changed, 178 insertions(+), 169 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20240920192804408249.json
 delete mode 100644 graphrag/index/verbs/overrides/merge.py
 create mode 100644 graphrag/index/workflows/v1/subflows/create_base_documents.py
 create mode 100644 tests/verbs/test_create_base_documents.py

diff --git a/.semversioner/next-release/patch-20240920192804408249.json b/.semversioner/next-release/patch-20240920192804408249.json
new file mode 100644
index 0000000000..25b8f20598
--- /dev/null
+++ b/.semversioner/next-release/patch-20240920192804408249.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse create_base_documents."
+}
diff --git a/graphrag/index/verbs/__init__.py b/graphrag/index/verbs/__init__.py
index 379c2a3749..3b37e3cb18 100644
--- a/graphrag/index/verbs/__init__.py
+++ b/graphrag/index/verbs/__init__.py
@@ -15,7 +15,7 @@
     merge_graphs,
     unpack_graph,
 )
-from .overrides import aggregate, concat, merge
+from .overrides import aggregate, concat
 from .snapshot import snapshot
 from .snapshot_rows import snapshot_rows
 from .spread_json import spread_json
@@ -35,7 +35,6 @@
     "extract_covariates",
     "genid",
     "layout_graph",
-    "merge",
     "merge_graphs",
     "snapshot",
     "snapshot_rows",
diff --git a/graphrag/index/verbs/overrides/__init__.py b/graphrag/index/verbs/overrides/__init__.py
index 24b82c1f3e..1c42b34e1a 100644
--- a/graphrag/index/verbs/overrides/__init__.py
+++ b/graphrag/index/verbs/overrides/__init__.py
@@ -5,6 +5,5 @@
 
 from .aggregate import aggregate
 from .concat import concat
-from .merge import merge
 
-__all__ = ["aggregate", "concat", "merge"]
+__all__ = ["aggregate", "concat"]
diff --git a/graphrag/index/verbs/overrides/merge.py b/graphrag/index/verbs/overrides/merge.py
deleted file mode 100644
index 64684c9828..0000000000
--- a/graphrag/index/verbs/overrides/merge.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing merge and _merge_json methods definition."""
-
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-import logging
-from enum import Enum
-from typing import Any, cast
-
-import pandas as pd
-from datashaper import TableContainer, VerbInput, VerbResult, verb
-from datashaper.engine.verbs.merge import merge as ds_merge
-
-log = logging.getLogger(__name__)
-
-
-class MergeStrategyType(str, Enum):
-    """MergeStrategy class definition."""
-
-    json = "json"
-    datashaper = "datashaper"
-
-    def __repr__(self):
-        """Get a string representation."""
-        return f'"{self.value}"'
-
-
-# TODO: This thing is kinda gross
-# Also, it diverges from the original aggregate verb, since it doesn't support the same syntax
-@verb(name="merge_override")
-def merge(
-    input: VerbInput,
-    to: str,
-    columns: list[str],
-    strategy: MergeStrategyType = MergeStrategyType.datashaper,
-    delimiter: str = "",
-    preserveSource: bool = False,  # noqa N806
-    unhot: bool = False,
-    prefix: str = "",
-    **_kwargs: dict,
-) -> TableContainer | VerbResult:
-    """Merge method definition."""
-    output: pd.DataFrame
-    match strategy:
-        case MergeStrategyType.json:
-            output = _merge_json(input, to, columns)
-            filtered_list: list[str] = []
-
-            for col in output.columns:
-                try:
-                    columns.index(col)
-                except ValueError:
-                    log.exception("Column %s not found in input columns", col)
-                    filtered_list.append(col)
-
-            if not preserveSource:
-                output = cast(Any, output[filtered_list])
-            return TableContainer(table=output.reset_index())
-        case _:
-            return ds_merge(
-                input, to, columns, strategy, delimiter, preserveSource, unhot, prefix
-            )
-
-
-def _merge_json(
-    input: VerbInput,
-    to: str,
-    columns: list[str],
-) -> pd.DataFrame:
-    input_table = cast(pd.DataFrame, input.get_input())
-    output = input_table
-    output[to] = output[columns].apply(
-        lambda row: ({**row}),
-        axis=1,
-    )
-    return output
diff --git a/graphrag/index/workflows/v1/create_base_documents.py b/graphrag/index/workflows/v1/create_base_documents.py
index bd7094c64a..9fd08a7ed6 100644
--- a/graphrag/index/workflows/v1/create_base_documents.py
+++ b/graphrag/index/workflows/v1/create_base_documents.py
@@ -22,84 +22,13 @@ def build_steps(
     document_attribute_columns = config.get("document_attribute_columns", [])
     return [
         {
-            "verb": "unroll",
-            "args": {"column": "document_ids"},
-            "input": {"source": "workflow:create_final_text_units"},
-        },
-        {
-            "verb": "select",
-            "args": {
-                # We only need the chunk id and the document id
-                "columns": ["id", "document_ids", "text"]
-            },
-        },
-        {
-            "id": "rename_chunk_doc_id",
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "document_ids": "chunk_doc_id",
-                    "id": "chunk_id",
-                    "text": "chunk_text",
-                }
-            },
-        },
-        {
-            "verb": "join",
-            "args": {
-                # Join the doc id from the chunk onto the original document
-                "on": ["chunk_doc_id", "id"]
-            },
-            "input": {"source": "rename_chunk_doc_id", "others": [DEFAULT_INPUT_NAME]},
-        },
-        {
-            "id": "docs_with_text_units",
-            "verb": "aggregate_override",
-            "args": {
-                "groupby": ["id"],
-                "aggregations": [
-                    {
-                        "column": "chunk_id",
-                        "operation": "array_agg",
-                        "to": "text_units",
-                    }
-                ],
-            },
-        },
-        {
-            "verb": "join",
+            "verb": "create_base_documents",
             "args": {
-                "on": ["id", "id"],
-                "strategy": "right outer",
+                "document_attribute_columns": document_attribute_columns,
             },
             "input": {
-                "source": "docs_with_text_units",
-                "others": [DEFAULT_INPUT_NAME],
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {"columns": {"text": "raw_content"}},
-        },
-        *[
-            {
-                "verb": "convert",
-                "args": {
-                    "column": column,
-                    "to": column,
-                    "type": "string",
-                },
-            }
-            for column in document_attribute_columns
-        ],
-        {
-            "verb": "merge_override",
-            "enabled": len(document_attribute_columns) > 0,
-            "args": {
-                "columns": document_attribute_columns,
-                "strategy": "json",
-                "to": "attributes",
+                "source": DEFAULT_INPUT_NAME,
+                "others": ["workflow:create_final_text_units"],
             },
         },
-        {"verb": "convert", "args": {"column": "id", "to": "id", "type": "string"}},
     ]
diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py
index 7ea0346f86..38cd0791d4 100644
--- a/graphrag/index/workflows/v1/subflows/__init__.py
+++ b/graphrag/index/workflows/v1/subflows/__init__.py
@@ -3,6 +3,7 @@
 
 """The Indexing Engine workflows -> subflows package root."""
 
+from .create_base_documents import create_base_documents
 from .create_final_communities import create_final_communities
 from .create_final_nodes import create_final_nodes
 from .create_final_relationships_post_embedding import (
@@ -14,6 +15,7 @@
 from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
 
 __all__ = [
+    "create_base_documents",
     "create_final_communities",
     "create_final_nodes",
     "create_final_relationships_post_embedding",
diff --git a/graphrag/index/workflows/v1/subflows/create_base_documents.py b/graphrag/index/workflows/v1/subflows/create_base_documents.py
new file mode 100644
index 0000000000..7329b5ea78
--- /dev/null
+++ b/graphrag/index/workflows/v1/subflows/create_base_documents.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to transform base documents."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.verbs.overrides.aggregate import aggregate_df
+
+
+@verb(name="create_base_documents", treats_input_tables_as_immutable=True)
+def create_base_documents(
+    input: VerbInput,
+    document_attribute_columns: list[str] | None = None,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform base documents."""
+    source = cast(pd.DataFrame, input.get_input())
+    text_units = cast(pd.DataFrame, input.get_others()[0])
+
+    text_units = cast(
+        pd.DataFrame, text_units.explode("document_ids")[["id", "document_ids", "text"]]
+    )
+    text_units.rename(
+        columns={
+            "document_ids": "chunk_doc_id",
+            "id": "chunk_id",
+            "text": "chunk_text",
+        },
+        inplace=True,
+    )
+
+    joined = text_units.merge(
+        source,
+        left_on="chunk_doc_id",
+        right_on="id",
+        how="inner",
+    )
+
+    docs_with_text_units = aggregate_df(
+        joined,
+        groupby=["id"],
+        aggregations=[
+            {
+                "column": "chunk_id",
+                "operation": "array_agg",
+                "to": "text_units",
+            }
+        ],
+    )
+
+    rejoined = docs_with_text_units.merge(
+        source,
+        on="id",
+        how="right",
+    )
+    rejoined.rename(columns={"text": "raw_content"}, inplace=True)
+    rejoined["id"] = rejoined["id"].astype(str)
+
+    # attribute columns are converted to strings and then collapsed into a single json object
+    if document_attribute_columns:
+        for column in document_attribute_columns:
+            rejoined[column] = rejoined[column].astype(str)
+        rejoined["attributes"] = rejoined[document_attribute_columns].apply(
+            lambda row: {**row},
+            axis=1,
+        )
+        rejoined.drop(columns=document_attribute_columns, inplace=True)
+        rejoined.reset_index(drop=True, inplace=True)
+
+    return create_verb_result(
+        cast(
+            Table,
+            rejoined,
+        )
+    )
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index 0f7b0d4e17..4340232e26 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -113,7 +113,7 @@
             1,
             2000
         ],
-        "subworkflows": 8,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_final_documents": {
diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json
index 24b3d1347f..37f5132334 100644
--- a/tests/fixtures/text/config.json
+++ b/tests/fixtures/text/config.json
@@ -132,7 +132,7 @@
             1,
             2000
         ],
-        "subworkflows": 8,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_final_documents": {
diff --git a/tests/verbs/test_create_base_documents.py b/tests/verbs/test_create_base_documents.py
new file mode 100644
index 0000000000..1e182e26b6
--- /dev/null
+++ b/tests/verbs/test_create_base_documents.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.index.workflows.v1.create_base_documents import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    compare_outputs,
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_base_documents():
+    input_tables = load_input_tables(["workflow:create_final_text_units"])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    compare_outputs(actual, expected)
+
+
+async def test_create_base_documents_with_attribute_columns():
+    input_tables = load_input_tables(["workflow:create_final_text_units"])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["document_attribute_columns"] = ["title"]
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    # we should have dropped "title" and added "attributes"
+    # our test dataframe does not have attributes, so we'll assert without it
+    # and separately confirm it is in the output
+    compare_outputs(actual, expected, columns=["id", "text_units", "raw_content"])
+    assert len(actual.columns) == 4
+    assert "attributes" in actual.columns
diff --git a/tests/verbs/util.py b/tests/verbs/util.py
index 21cd9249d2..dcc9c4ea3c 100644
--- a/tests/verbs/util.py
+++ b/tests/verbs/util.py
@@ -19,6 +19,14 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
     """Harvest all the referenced input IDs from the workflow being tested and pass them here."""
     # stick all the inputs in a map - Workflow looks them up by name
     input_tables: dict[str, pd.DataFrame] = {}
+
+    # all workflows implicitly receive the `input` source, which is formatted as a dataframe after loading from storage
+    # we'll simulate that by just loading one of our output parquets and converting back to equivalent dataframe
+    # so we aren't dealing with storage vagaries (which would become an integration test)
+    source = pd.read_parquet("tests/verbs/data/create_base_documents.parquet")
+    source.rename(columns={"raw_content": "text"}, inplace=True)
+    input_tables["source"] = cast(pd.DataFrame, source[["id", "text", "title"]])
+
     for input in inputs:
         # remove the workflow: prefix if it exists, because that is not part of the actual table filename
         name = input.replace("workflow:", "")
@@ -63,18 +71,22 @@ def compare_outputs(
     """Compare the actual and expected dataframes, optionally specifying columns to compare.
 
     This uses assert_series_equal since we are sometimes intentionally omitting columns from the actual output."""
     cols = expected.columns if columns is None else columns
-    try:
-        assert len(actual) == len(expected)
-        assert len(actual.columns) == len(cols)
-        for column in cols:
+
+    assert len(actual) == len(
+        expected
+    ), f"Expected: {len(expected)}, Actual: {len(actual)}"
+
+    for column in cols:
+        assert column in actual.columns
+        try:
             # dtypes can differ since the test data is read from parquet and our workflow runs in memory
             assert_series_equal(actual[column], expected[column], check_dtype=False)
-    except AssertionError:
-        print("Expected:")
-        print(expected.head())
-        print("Actual:")
-        print(actual.head())
-        raise
+        except AssertionError:
+            print("Expected:")
+            print(expected[column])
+            print("Actual:")
+            print(actual[column])
+            raise
 
 
 def remove_disabled_steps(
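
Reviewer note (not part of the patch): the new subflow boils the old verb chain down to a handful of pandas operations. Below is a minimal standalone sketch of that pipeline; the input frames are hypothetical, and the internal aggregate_df helper is approximated with a plain groupby, so this is an illustration rather than the shipped code.

import pandas as pd

# Hypothetical stand-ins for the workflow's "input" source and text units.
source = pd.DataFrame({"id": ["d1"], "text": ["full document text"]})
text_units = pd.DataFrame({
    "id": ["c1", "c2"],
    "document_ids": [["d1"], ["d1"]],
    "text": ["chunk one", "chunk two"],
})

# One row per (chunk, document) pair, renamed to avoid column collisions.
chunks = text_units.explode("document_ids")[["id", "document_ids", "text"]].rename(
    columns={"id": "chunk_id", "document_ids": "chunk_doc_id", "text": "chunk_text"}
)

# Join chunks onto their documents, then collect the chunk ids per document.
joined = chunks.merge(source, left_on="chunk_doc_id", right_on="id", how="inner")
docs = joined.groupby("id", as_index=False).agg(text_units=("chunk_id", list))

# Right-join back so documents without chunks are kept, then rename text -> raw_content.
result = docs.merge(source, on="id", how="right").rename(columns={"text": "raw_content"})
print(result)  # one row: id="d1", text_units=["c1", "c2"], raw_content="full document text"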