diff --git a/graphrag/api/index.py b/graphrag/api/index.py
index 2ad41c2f40..d473487f0d 100644
--- a/graphrag/api/index.py
+++ b/graphrag/api/index.py
@@ -70,7 +70,9 @@ async def build_index(
     if memory_profile:
         log.warning("New pipeline does not yet support memory profiling.")
 
+    workflows = _get_workflow_list(config)
     async for output in run_workflows(
+        workflows,
         config,
         cache=pipeline_cache,
         callbacks=callbacks,
@@ -87,3 +89,24 @@ async def build_index(
         progress_logger.info(str(output.result))
 
     return outputs
+
+
+# TODO: this is staging for being able to select a set of default workflows based on config and API params
+# right now the primary test we do is whether to include claim extraction or not.
+# this will eventually move into config as a list, populated via CLI params.
+def _get_workflow_list(config: GraphRagConfig) -> list[str]:
+    """Get the list of workflows from the config."""
+    return [
+        "create_base_text_units",
+        "create_final_documents",
+        "extract_graph",
+        "compute_communities",
+        "create_final_entities",
+        "create_final_relationships",
+        "create_final_nodes",
+        "create_final_communities",
+        *(["create_final_covariates"] if config.claim_extraction.enabled else []),
+        "create_final_text_units",
+        "create_final_community_reports",
+        "generate_text_embeddings",
+    ]
diff --git a/graphrag/index/run/run_workflows.py b/graphrag/index/run/run_workflows.py
index c264462799..fe2deef3a8 100644
--- a/graphrag/index/run/run_workflows.py
+++ b/graphrag/index/run/run_workflows.py
@@ -38,22 +38,6 @@
 
 log = logging.getLogger(__name__)
 
-
-default_workflows = [
-    "create_base_text_units",
-    "create_final_documents",
-    "extract_graph",
-    "compute_communities",
-    "create_final_entities",
-    "create_final_relationships",
-    "create_final_nodes",
-    "create_final_communities",
-    "create_final_covariates",
-    "create_final_text_units",
-    "create_final_community_reports",
-    "generate_text_embeddings",
-]
-
 # these are transient outputs written to storage for downstream workflow use
 # they are not required after indexing, so we'll clean them up at the end for clarity
 # (unless snapshots.transient is set!)
@@ -67,6 +51,7 @@
 
 
 async def run_workflows(
+    workflows: list[str],
     config: GraphRagConfig,
     cache: PipelineCache | None = None,
     callbacks: list[WorkflowCallbacks] | None = None,
@@ -115,6 +100,7 @@ async def run_workflows(
         # Run the pipeline on the new documents
         tables_dict = {}
         async for table in _run_workflows(
+            workflows=workflows,
             config=config,
             dataset=delta_dataset.new_inputs,
             cache=cache,
@@ -140,6 +126,7 @@ async def run_workflows(
         progress_logger.info("Running standard indexing.")
 
         async for table in _run_workflows(
+            workflows=workflows,
             config=config,
             dataset=dataset,
             cache=cache,
@@ -151,6 +138,7 @@ async def run_workflows(
 
 
 async def _run_workflows(
+    workflows: list[str],
     config: GraphRagConfig,
     dataset: pd.DataFrame,
     cache: PipelineCache,
@@ -170,7 +158,7 @@ async def _run_workflows(
     await _dump_stats(context.stats, context.storage)
     await write_table_to_storage(dataset, "input", context.storage)
 
-    for workflow in default_workflows:
+    for workflow in workflows:
        last_workflow = workflow
        run_workflow = all_workflows[workflow]
        progress = logger.child(workflow, transient=False)
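The behavioral effect of this change is that `_get_workflow_list` now decides whether `create_final_covariates` appears in the pipeline, based on `config.claim_extraction.enabled`. Below is a minimal, self-contained sketch of that selection logic; the `_StubConfig` and `_ClaimExtraction` types are illustrative stand-ins for `GraphRagConfig`, not part of the graphrag API, and the helper simply mirrors the one added in `graphrag/api/index.py` above.

```python
# Illustrative sketch only: stub config objects stand in for GraphRagConfig
# to show how the claim-extraction toggle changes the selected workflow list.
from dataclasses import dataclass


@dataclass
class _ClaimExtraction:
    enabled: bool


@dataclass
class _StubConfig:
    claim_extraction: _ClaimExtraction


def _get_workflow_list(config: _StubConfig) -> list[str]:
    # Mirrors the helper introduced in graphrag/api/index.py in this diff.
    return [
        "create_base_text_units",
        "create_final_documents",
        "extract_graph",
        "compute_communities",
        "create_final_entities",
        "create_final_relationships",
        "create_final_nodes",
        "create_final_communities",
        *(["create_final_covariates"] if config.claim_extraction.enabled else []),
        "create_final_text_units",
        "create_final_community_reports",
        "generate_text_embeddings",
    ]


with_claims = _get_workflow_list(_StubConfig(_ClaimExtraction(enabled=True)))
without_claims = _get_workflow_list(_StubConfig(_ClaimExtraction(enabled=False)))

# Only the covariates workflow differs between the two configurations.
assert "create_final_covariates" in with_claims
assert "create_final_covariates" not in without_claims
assert len(with_claims) == len(without_claims) + 1
```

The resulting list is then passed through `run_workflows` to `_run_workflows`, which iterates it in place of the old module-level `default_workflows` constant removed from `run_workflows.py`.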