Add new basic pipeline runner
natoverse committed Dec 28, 2024
1 parent a2647da commit 2b119e1
Showing 18 changed files with 680 additions and 62 deletions.
48 changes: 30 additions & 18 deletions graphrag/api/index.py
@@ -16,6 +16,7 @@
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.run.run_workflows import run_workflows
from graphrag.index.typing import PipelineRunResult
from graphrag.logger.base import ProgressLogger

@@ -27,6 +28,7 @@ async def build_index(
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_logger: ProgressLogger | None = None,
use_new_pipeline: bool = False,
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.
@@ -56,7 +58,6 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
)
@@ -65,21 +66,32 @@
callbacks = callbacks or []
callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore
outputs: list[PipelineRunResult] = []
async for output in run_pipeline_with_config(
pipeline_config,
run_id=run_id,
memory_profile=memory_profile,
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
is_resume_run=is_resume_run,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_logger.error(output.workflow)
else:
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))

if use_new_pipeline:
await run_workflows(
config,
cache=pipeline_cache,
logger=progress_logger,
run_id=run_id,
)
else:
pipeline_config = create_pipeline_config(config)
async for output in run_pipeline_with_config(
pipeline_config,
run_id=run_id,
memory_profile=memory_profile,
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
is_resume_run=is_resume_run,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_logger.error(output.workflow)
else:
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))

return outputs
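
For context, here is a minimal, hypothetical sketch of how a caller might opt into the new runner through the flag added above (not part of this commit; any build_index keyword arguments beyond those visible in the diff are assumed). Note that when use_new_pipeline is True, this commit does not append anything to outputs, so the returned list is empty and per-workflow results are only reported on the legacy path.

# Hypothetical usage sketch; `config` is assumed to be a fully loaded GraphRagConfig,
# e.g. the one the CLI builds before calling build_index.
from graphrag.api.index import build_index
from graphrag.config.models.graph_rag_config import GraphRagConfig


async def run_index(config: GraphRagConfig) -> None:
    results = await build_index(
        config,
        use_new_pipeline=True,  # route through run_workflows instead of the legacy runner
    )
    # Empty on the new path in this commit; the legacy path yields one entry per workflow.
    for result in results:
        print(result.workflow, result.errors)
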
1 change: 1 addition & 0 deletions graphrag/cli/index.py
@@ -182,6 +182,7 @@ def _run_index(
is_resume_run=bool(resume),
memory_profile=memprofile,
progress_logger=progress_logger,
use_new_pipeline=True,
)
)
encountered_errors = any(
42 changes: 42 additions & 0 deletions graphrag/index/config/embeddings.py
@@ -3,6 +3,10 @@

"""A module containing embeddings values."""

from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig

entity_title_embedding = "entity.title"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
@@ -27,3 +31,41 @@
community_full_content_embedding,
text_unit_text_embedding,
}


def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
"""Get the fields to embed based on the enum or specifically skipped embeddings."""
match settings.embeddings.target:
case TextEmbeddingTarget.all:
return all_embeddings.difference(settings.embeddings.skip)
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embeddings.target}"
raise ValueError(msg)


def get_embedding_settings(
settings: TextEmbeddingConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
vector_store_settings = settings.vector_store
if vector_store_settings is None:
return {"strategy": settings.resolved_strategy()}
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.resolved_strategy() # get the default strategy
strategy.update({
"vector_store": {**(vector_store_params or {}), **vector_store_settings}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}
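
As a hypothetical illustration (not part of this commit), the two helpers above can be combined the same way create_pipeline_config does below: get_embedded_fields picks which fields to embed from the configured target and skip list, while get_embedding_settings resolves the embedding strategy and folds any per-embedding vector store settings into it.

# Hypothetical illustration of the two helpers; `settings` is assumed to be a
# fully loaded GraphRagConfig.
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.config.embeddings import (
    get_embedded_fields,
    get_embedding_settings,
)


def build_text_embed_config(settings: GraphRagConfig) -> dict:
    # Everything minus skips, only the required set, or nothing, depending on
    # settings.embeddings.target.
    embedded_fields = get_embedded_fields(settings)
    # Default strategy, with per-embedding vector store settings merged in when present.
    text_embed = get_embedding_settings(settings.embeddings)
    return {"embedded_fields": embedded_fields, "text_embed": text_embed}
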
46 changes: 3 additions & 43 deletions graphrag/index/create_pipeline_config.py
@@ -12,11 +12,9 @@
InputFileType,
ReportingType,
StorageType,
TextEmbeddingTarget,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineCacheConfigTypes,
@@ -25,10 +23,7 @@
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from graphrag.index.config.embeddings import (
all_embeddings,
required_embeddings,
)
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.config.input import (
PipelineCSVInputConfig,
PipelineInputConfigTypes,
@@ -92,7 +87,7 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
_log_llm_settings(settings)

skip_workflows = settings.skip_workflows
embedded_fields = _get_embedded_fields(settings)
embedded_fields = get_embedded_fields(settings)
covariates_enabled = (
settings.claim_extraction.enabled
and create_final_covariates not in skip_workflows
@@ -123,19 +118,6 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
return result


def _get_embedded_fields(settings: GraphRagConfig) -> set[str]:
match settings.embeddings.target:
case TextEmbeddingTarget.all:
return all_embeddings.difference(settings.embeddings.skip)
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embeddings.target}"
raise ValueError(msg)


def _log_llm_settings(settings: GraphRagConfig) -> None:
log.info(
"Using LLM Config %s",
@@ -189,28 +171,6 @@ def _text_unit_workflows(
]


def _get_embedding_settings(
settings: TextEmbeddingConfig,
vector_store_params: dict | None = None,
) -> dict:
vector_store_settings = settings.vector_store
if vector_store_settings is None:
return {"strategy": settings.resolved_strategy()}
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.resolved_strategy() # get the default strategy
strategy.update({
"vector_store": {**(vector_store_params or {}), **vector_store_settings}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}


def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
@@ -307,7 +267,7 @@ def _embeddings_workflows(
name=generate_text_embeddings,
config={
"snapshot_embeddings": settings.snapshots.embeddings,
"text_embed": _get_embedding_settings(settings.embeddings),
"text_embed": get_embedding_settings(settings.embeddings),
"embedded_fields": embedded_fields,
},
),
115 changes: 115 additions & 0 deletions graphrag/index/run/run_workflows.py
@@ -0,0 +1,115 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Different methods to run the pipeline."""

import logging
import time

from datashaper import VerbCallbacks
from datashaper.progress.types import Progress

from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import create_input
from graphrag.index.run.profiling import _dump_stats
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.default_workflows import basic_workflows
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.factory import StorageFactory

log = logging.getLogger(__name__)


default_workflows = [
"create_base_text_units",
"create_final_documents",
"extract_graph",
"create_final_covariates",
"compute_communities",
"create_final_entities",
"create_final_relationships",
"create_final_nodes",
"create_final_communities",
"create_final_text_units",
"create_final_community_reports",
"generate_text_embeddings",
]


async def run_workflows(
config: GraphRagConfig,
cache: PipelineCache | None = None,
logger: ProgressLogger | None = None,
run_id: str | None = None,
):
"""Run all workflows using a simplified pipeline."""
print("RUNNING NEW PIPELINE")
print(config)

start_time = time.time()

run_id = run_id or time.strftime("%Y%m%d-%H%M%S")
root_dir = config.root_dir or ""
progress_logger = logger or NullProgressLogger()
storage_config = config.storage.model_dump() # type: ignore
storage = StorageFactory().create_storage(
storage_type=storage_config["type"], # type: ignore
kwargs=storage_config,
)
cache_config = config.cache.model_dump() # type: ignore
cache = cache or CacheFactory().create_cache(
cache_type=cache_config["type"], # type: ignore
root_dir=root_dir,
kwargs=cache_config,
)

context = create_run_context(storage=storage, cache=cache, stats=None)

dataset = await create_input(config.input, progress_logger, root_dir)
log.info("Final # of rows loaded: %s", len(dataset))
context.stats.num_documents = len(dataset)

await context.runtime_storage.set("input", dataset)

for workflow in default_workflows:
print("RUNNING WORKFLOW", workflow)
run_workflow = basic_workflows[workflow]
verb_callbacks = DelegatingCallbacks()
work_time = time.time()
await run_workflow(
config,
context,
verb_callbacks,
)
context.stats.workflows[workflow] = {"overall": time.time() - work_time}

context.stats.total_runtime = time.time() - start_time
await _dump_stats(context.stats, context.storage)


class DelegatingCallbacks(VerbCallbacks):
"""TEMP: this is all to wrap into DataShaper callbacks that the flows expect."""

def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""

def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""

def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""

def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log occurs."""

def measure(self, name: str, value: float, details: dict | None = None) -> None:
"""Handle when a measurement occurs."""