Skip to content

Commit

Permalink
Revisit create final text units (#1216)
Browse files Browse the repository at this point in the history
* Add embeddings to collapsed subflow

* Semver

* Fix smoke tests
  • Loading branch information
natoverse authored Sep 25, 2024
1 parent 73e709b commit 3217013
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 51 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240925214723888952.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add embeddings to subflow."
}
41 changes: 4 additions & 37 deletions graphrag/index/workflows/v1/create_final_text_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,8 @@ def build_steps(
"""
base_text_embed = config.get("text_embed", {})
text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed)
covariates_enabled = config.get("covariates_enabled", False)
skip_text_unit_embedding = config.get("skip_text_unit_embedding", False)
is_using_vector_store = (
text_unit_text_embed_config.get("strategy", {}).get("vector_store", None)
is not None
)
covariates_enabled = config.get("covariates_enabled", False)

others = [
"workflow:create_final_entities",
Expand All @@ -37,44 +33,15 @@ def build_steps(

return [
{
"verb": "create_final_text_units_pre_embedding",
"verb": "create_final_text_units",
"args": {
"skip_embedding": skip_text_unit_embedding,
"text_embed": text_unit_text_embed_config,
"covariates_enabled": covariates_enabled,
},
"input": {
"source": "workflow:create_base_text_units",
"others": others,
},
},
# Text-Embed after final aggregations
{
"id": "embedded_text_units",
"verb": "text_embed",
"enabled": not skip_text_unit_embedding,
"args": {
"column": config.get("column", "text"),
"to": config.get("to", "text_embedding"),
**text_unit_text_embed_config,
},
},
{
"verb": "select",
"args": {
# Final select to get output in the correct shape
"columns": [
"id",
"text",
*(
[]
if (skip_text_unit_embedding or is_using_vector_store)
else ["text_embedding"]
),
"n_tokens",
"document_ids",
"entity_ids",
"relationship_ids",
*([] if not covariates_enabled else ["covariate_ids"]),
],
},
},
]
4 changes: 2 additions & 2 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
from .create_final_text_units import create_final_text_units

__all__ = [
"create_base_documents",
Expand All @@ -22,5 +22,5 @@
"create_final_documents",
"create_final_nodes",
"create_final_relationships",
"create_final_text_units_pre_embedding",
"create_final_text_units",
]
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform before we embed the text units."""
"""All the steps to transform the text units."""

from typing import cast

import pandas as pd
from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
VerbResult,
create_verb_result,
verb,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.text.embed.text_embed import text_embed_df

@verb(
name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
)
def create_final_text_units_pre_embedding(

@verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
async def create_final_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_embed: dict,
skip_embedding: bool = False,
covariates_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform before we embed the text units."""
"""All the steps to transform the text units."""
table = cast(pd.DataFrame, input.get_input())
others = input.get_others()

Expand All @@ -43,7 +53,33 @@ def create_final_text_units_pre_embedding(

aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()

return create_verb_result(cast(Table, aggregated))
if not skip_embedding:
aggregated = await text_embed_df(
aggregated,
callbacks,
cache,
column="text",
strategy=text_embed["strategy"],
to="text_embedding",
)

is_using_vector_store = (
text_embed.get("strategy", {}).get("vector_store", None) is not None
)

final = aggregated[
[
"id",
"text",
*([] if (skip_embedding or is_using_vector_store) else ["text_embedding"]),
"n_tokens",
"document_ids",
"entity_ids",
"relationship_ids",
*([] if not covariates_enabled else ["covariate_ids"]),
]
]
return create_verb_result(cast(Table, final))


def _entities(df: pd.DataFrame) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
"relationship_ids",
"entity_ids"
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_base_documents": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
"relationship_ids",
"entity_ids"
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_base_documents": {
Expand Down
32 changes: 32 additions & 0 deletions tests/verbs/test_create_final_text_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,35 @@ async def test_create_final_text_units_no_covariates():
expected,
["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"],
)


async def test_create_final_text_units_with_embeddings():
input_tables = load_input_tables([
"workflow:create_base_text_units",
"workflow:create_final_entities",
"workflow:create_final_relationships",
"workflow:create_final_covariates",
])
expected = load_expected(workflow_name)

config = get_config_for_workflow(workflow_name)

config["covariates_enabled"] = True
config["skip_text_unit_embedding"] = False
# default config has a detailed standard embed config
# just override the strategy to mock so the rest of the required parameters are in place
config["text_unit_text_embed"]["strategy"]["type"] = "mock"

steps = remove_disabled_steps(build_steps(config))

actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)

assert "text_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["text_embedding"][0]) == 3

0 comments on commit 3217013

Please sign in to comment.