diff --git a/.semversioner/next-release/patch-20240925214723888952.json b/.semversioner/next-release/patch-20240925214723888952.json new file mode 100644 index 0000000000..6b50281476 --- /dev/null +++ b/.semversioner/next-release/patch-20240925214723888952.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add embeddings to subflow." +} diff --git a/graphrag/index/workflows/v1/create_final_text_units.py b/graphrag/index/workflows/v1/create_final_text_units.py index 4fa91b4488..1117a36364 100644 --- a/graphrag/index/workflows/v1/create_final_text_units.py +++ b/graphrag/index/workflows/v1/create_final_text_units.py @@ -21,12 +21,8 @@ def build_steps( """ base_text_embed = config.get("text_embed", {}) text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed) - covariates_enabled = config.get("covariates_enabled", False) skip_text_unit_embedding = config.get("skip_text_unit_embedding", False) - is_using_vector_store = ( - text_unit_text_embed_config.get("strategy", {}).get("vector_store", None) - is not None - ) + covariates_enabled = config.get("covariates_enabled", False) others = [ "workflow:create_final_entities", @@ -37,8 +33,10 @@ def build_steps( return [ { - "verb": "create_final_text_units_pre_embedding", + "verb": "create_final_text_units", "args": { + "skip_embedding": skip_text_unit_embedding, + "text_embed": text_unit_text_embed_config, "covariates_enabled": covariates_enabled, }, "input": { @@ -46,35 +44,4 @@ def build_steps( "others": others, }, }, - # Text-Embed after final aggregations - { - "id": "embedded_text_units", - "verb": "text_embed", - "enabled": not skip_text_unit_embedding, - "args": { - "column": config.get("column", "text"), - "to": config.get("to", "text_embedding"), - **text_unit_text_embed_config, - }, - }, - { - "verb": "select", - "args": { - # Final select to get output in the correct shape - "columns": [ - "id", - "text", - *( - [] - if (skip_text_unit_embedding or is_using_vector_store) - else ["text_embedding"] - ), - "n_tokens", - "document_ids", - "entity_ids", - "relationship_ids", - *([] if not covariates_enabled else ["covariate_ids"]), - ], - }, - }, ] diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index 232e646bcb..d8d201f389 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -12,7 +12,7 @@ from .create_final_relationships import ( create_final_relationships, ) -from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding +from .create_final_text_units import create_final_text_units __all__ = [ "create_base_documents", @@ -22,5 +22,5 @@ "create_final_documents", "create_final_nodes", "create_final_relationships", - "create_final_text_units_pre_embedding", + "create_final_text_units", ] diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py b/graphrag/index/workflows/v1/subflows/create_final_text_units.py similarity index 63% rename from graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py rename to graphrag/index/workflows/v1/subflows/create_final_text_units.py index 49ebd81986..f9b6d29f92 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py +++ b/graphrag/index/workflows/v1/subflows/create_final_text_units.py @@ -1,25 +1,35 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""All the steps to transform before we embed the text units.""" +"""All the steps to transform the text units.""" from typing import cast import pandas as pd -from datashaper.engine.verbs.verb_input import VerbInput -from datashaper.engine.verbs.verbs_mapping import verb -from datashaper.table_store.types import Table, VerbResult, create_verb_result +from datashaper import ( + Table, + VerbCallbacks, + VerbInput, + VerbResult, + create_verb_result, + verb, +) +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.text.embed.text_embed import text_embed_df -@verb( - name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True -) -def create_final_text_units_pre_embedding( + +@verb(name="create_final_text_units", treats_input_tables_as_immutable=True) +async def create_final_text_units( input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + text_embed: dict, + skip_embedding: bool = False, covariates_enabled: bool = False, **_kwargs: dict, ) -> VerbResult: - """All the steps to transform before we embed the text units.""" + """All the steps to transform the text units.""" table = cast(pd.DataFrame, input.get_input()) others = input.get_others() @@ -43,7 +53,33 @@ def create_final_text_units_pre_embedding( aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index() - return create_verb_result(cast(Table, aggregated)) + if not skip_embedding: + aggregated = await text_embed_df( + aggregated, + callbacks, + cache, + column="text", + strategy=text_embed["strategy"], + to="text_embedding", + ) + + is_using_vector_store = ( + text_embed.get("strategy", {}).get("vector_store", None) is not None + ) + + final = aggregated[ + [ + "id", + "text", + *([] if (skip_embedding or is_using_vector_store) else ["text_embedding"]), + "n_tokens", + "document_ids", + "entity_ids", + "relationship_ids", + *([] if not covariates_enabled else ["covariate_ids"]), + ] + ] + return create_verb_result(cast(Table, final)) def _entities(df: pd.DataFrame) -> pd.DataFrame: diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 635bf9e5cd..684062773e 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -105,7 +105,7 @@ "relationship_ids", "entity_ids" ], - "subworkflows": 2, + "subworkflows": 1, "max_runtime": 100 }, "create_base_documents": { diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 30c01be5d6..02f803dea2 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -122,7 +122,7 @@ "relationship_ids", "entity_ids" ], - "subworkflows": 2, + "subworkflows": 1, "max_runtime": 100 }, "create_base_documents": { diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index ea2a148cc9..a64c9ff5a5 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -71,3 +71,35 @@ async def test_create_final_text_units_no_covariates(): expected, ["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"], ) + + +async def test_create_final_text_units_with_embeddings(): + input_tables = load_input_tables([ + "workflow:create_base_text_units", + "workflow:create_final_entities", + "workflow:create_final_relationships", + "workflow:create_final_covariates", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["covariates_enabled"] = True + config["skip_text_unit_embedding"] = False + # default config has a detailed standard embed config + # just override the strategy to mock so the rest of the required parameters are in place + config["text_unit_text_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert "text_embedding" in actual.columns + assert len(actual.columns) == len(expected.columns) + 1 + # the mock impl returns an array of 3 floats for each embedding + assert len(actual["text_embedding"][0]) == 3