Move workflow list to conditional construction
natoverse committed Jan 3, 2025
1 parent ffa929c commit 575309d
Showing 2 changed files with 28 additions and 17 deletions.
graphrag/api/index.py (23 additions, 0 deletions)
@@ -70,7 +70,9 @@ async def build_index(
     if memory_profile:
         log.warning("New pipeline does not yet support memory profiling.")
 
+    workflows = _get_workflow_list(config)
     async for output in run_workflows(
+        workflows,
         config,
         cache=pipeline_cache,
         callbacks=callbacks,
@@ -87,3 +89,24 @@ async def build_index(
         progress_logger.info(str(output.result))
 
     return outputs
+
+
+# TODO: this is staging for being able to select a set of default workflows based on config and API params
+# right now the primary test we do is whether to include claim extraction or not.
+# this will eventually move into config as a list, populated via CLI params.
+def _get_workflow_list(config: GraphRagConfig) -> list[str]:
+    """Get the list of workflows from the config."""
+    return [
+        "create_base_text_units",
+        "create_final_documents",
+        "extract_graph",
+        "compute_communities",
+        "create_final_entities",
+        "create_final_relationships",
+        "create_final_nodes",
+        "create_final_communities",
+        *(["create_final_covariates"] if config.claim_extraction.enabled else []),
+        "create_final_text_units",
+        "create_final_community_reports",
+        "generate_text_embeddings",
+    ]
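The only conditional element above is create_final_covariates, spliced in with a starred expression. As a minimal standalone sketch of that idiom (the claims_enabled parameter is a hypothetical stand-in for config.claim_extraction.enabled):

    # Minimal sketch of the conditional-splat idiom from _get_workflow_list.
    # claims_enabled is a stand-in for config.claim_extraction.enabled.
    def workflow_list(claims_enabled: bool) -> list[str]:
        return [
            "extract_graph",
            # The starred inner list contributes its elements only when the
            # flag is set; otherwise the empty list adds nothing.
            *(["create_final_covariates"] if claims_enabled else []),
            "generate_text_embeddings",
        ]

    assert "create_final_covariates" in workflow_list(claims_enabled=True)
    assert "create_final_covariates" not in workflow_list(claims_enabled=False)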
graphrag/index/run/run_workflows.py (5 additions, 17 deletions)
@@ -38,22 +38,6 @@
 
 log = logging.getLogger(__name__)
 
-
-default_workflows = [
-    "create_base_text_units",
-    "create_final_documents",
-    "extract_graph",
-    "compute_communities",
-    "create_final_entities",
-    "create_final_relationships",
-    "create_final_nodes",
-    "create_final_communities",
-    "create_final_covariates",
-    "create_final_text_units",
-    "create_final_community_reports",
-    "generate_text_embeddings",
-]
-
 # these are transient outputs written to storage for downstream workflow use
 # they are not required after indexing, so we'll clean them up at the end for clarity
 # (unless snapshots.transient is set!)
@@ -67,6 +51,7 @@
 
 
 async def run_workflows(
+    workflows: list[str],
     config: GraphRagConfig,
     cache: PipelineCache | None = None,
     callbacks: list[WorkflowCallbacks] | None = None,
@@ -115,6 +100,7 @@ async def run_workflows(
         # Run the pipeline on the new documents
         tables_dict = {}
         async for table in _run_workflows(
+            workflows=workflows,
             config=config,
             dataset=delta_dataset.new_inputs,
             cache=cache,
@@ -140,6 +126,7 @@ async def run_workflows(
         progress_logger.info("Running standard indexing.")
 
         async for table in _run_workflows(
+            workflows=workflows,
             config=config,
             dataset=dataset,
             cache=cache,
@@ -151,6 +138,7 @@
 
 
 async def _run_workflows(
+    workflows: list[str],
     config: GraphRagConfig,
     dataset: pd.DataFrame,
     cache: PipelineCache,
@@ -170,7 +158,7 @@ async def _run_workflows(
     await _dump_stats(context.stats, context.storage)
     await write_table_to_storage(dataset, "input", context.storage)
 
-    for workflow in default_workflows:
+    for workflow in workflows:
         last_workflow = workflow
         run_workflow = all_workflows[workflow]
         progress = logger.child(workflow, transient=False)
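With both files in place, selection and execution are decoupled: build_index computes the list via _get_workflow_list, and run_workflows simply iterates whatever it receives. A hedged sketch of a caller passing an explicit subset (the subset chosen, the cache and callbacks variables, and the print usage are illustrative, not part of this commit):

    # Hypothetical caller passing an explicit subset of workflows; this is
    # the flexibility the refactor stages for. cache/callbacks setup is
    # assumed to exist in the surrounding code, as in build_index.
    subset = [
        "create_base_text_units",
        "create_final_documents",
    ]
    async for output in run_workflows(subset, config, cache=cache, callbacks=callbacks):
        print(str(output.result))

Note that ordering and completeness still matter: as the transient-outputs comment in run_workflows.py indicates, earlier workflows write tables to storage that later ones read, so an arbitrary subset is not guaranteed to succeed.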
