diff --git a/.semversioner/next-release/patch-20250103231659816022.json b/.semversioner/next-release/patch-20250103231659816022.json new file mode 100644 index 0000000000..3a6ade4515 --- /dev/null +++ b/.semversioner/next-release/patch-20250103231659816022.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Simplify callbacks model." +} diff --git a/docs/examples_notebooks/index_migration.ipynb b/docs/examples_notebooks/index_migration.ipynb index 5021fa2cbb..4dee757871 100644 --- a/docs/examples_notebooks/index_migration.ipynb +++ b/docs/examples_notebooks/index_migration.ipynb @@ -207,7 +207,7 @@ "outputs": [], "source": [ "from graphrag.cache.factory import create_cache\n", - "from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks\n", + "from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n", "from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n", "\n", "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n", @@ -219,7 +219,7 @@ "config = workflow.config\n", "text_embed = config.get(\"text_embed\", {})\n", "embedded_fields = config.get(\"embedded_fields\", {})\n", - "callbacks = NoopVerbCallbacks()\n", + "callbacks = NoopWorkflowCallbacks()\n", "cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n", "\n", "await generate_text_embeddings(\n", diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index 98c1dac3ba..30b7c682cb 100644 --- a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -13,7 +13,7 @@ from pydantic import PositiveInt, validate_call -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.llm.load_llm import load_llm from graphrag.logger.print_progress import PrintProgressLogger @@ -99,7 +99,7 @@ async def generate_indexing_prompts( "prompt_tuning", config.llm, cache=None, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), ) if not domain: diff --git a/graphrag/callbacks/blob_workflow_callbacks.py b/graphrag/callbacks/blob_workflow_callbacks.py index 36bd5f9e83..30391ecd8d 100644 --- a/graphrag/callbacks/blob_workflow_callbacks.py +++ b/graphrag/callbacks/blob_workflow_callbacks.py @@ -84,7 +84,7 @@ def _write_log(self, log: dict[str, Any]): # update the blob's block count self._num_blocks += 1 - def on_error( + def error( self, message: str, cause: BaseException | None = None, @@ -100,10 +100,10 @@ def on_error( "details": details, }) - def on_warning(self, message: str, details: dict | None = None): + def warning(self, message: str, details: dict | None = None): """Report a warning.""" self._write_log({"type": "warning", "data": message, "details": details}) - def on_log(self, message: str, details: dict | None = None): + def log(self, message: str, details: dict | None = None): """Report a generic log message.""" self._write_log({"type": "log", "data": message, "details": details}) diff --git a/graphrag/callbacks/console_workflow_callbacks.py b/graphrag/callbacks/console_workflow_callbacks.py index a2ab6ef08a..b1478085de 100644 --- a/graphrag/callbacks/console_workflow_callbacks.py +++ b/graphrag/callbacks/console_workflow_callbacks.py @@ -9,7 +9,7 @@ class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks): """A logger that writes to a console.""" - def on_error( + def error( self, message: str, cause: BaseException | None = None, @@ -19,11 +19,11 @@ def on_error( """Handle when an error occurs.""" print(message, str(cause), stack, details) # noqa T201 - def on_warning(self, message: str, details: dict | None = None): + def warning(self, message: str, details: dict | None = None): """Handle when a warning occurs.""" _print_warning(message) - def on_log(self, message: str, details: dict | None = None): + def log(self, message: str, details: dict | None = None): """Handle when a log message is produced.""" print(message, details) # noqa T201 diff --git a/graphrag/callbacks/delegating_verb_callbacks.py b/graphrag/callbacks/delegating_verb_callbacks.py deleted file mode 100644 index 11687f3a24..0000000000 --- a/graphrag/callbacks/delegating_verb_callbacks.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Contains the DelegatingVerbCallback definition.""" - -from graphrag.callbacks.verb_callbacks import VerbCallbacks -from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.logger.progress import Progress - - -class DelegatingVerbCallbacks(VerbCallbacks): - """A wrapper that implements VerbCallbacks that delegates to the underlying WorkflowCallbacks.""" - - _workflow_callbacks: WorkflowCallbacks - _name: str - - def __init__(self, name: str, workflow_callbacks: WorkflowCallbacks): - """Create a new instance of DelegatingVerbCallbacks.""" - self._workflow_callbacks = workflow_callbacks - self._name = name - - def progress(self, progress: Progress) -> None: - """Handle when progress occurs.""" - self._workflow_callbacks.on_step_progress(self._name, progress) - - def error( - self, - message: str, - cause: BaseException | None = None, - stack: str | None = None, - details: dict | None = None, - ) -> None: - """Handle when an error occurs.""" - self._workflow_callbacks.on_error(message, cause, stack, details) - - def warning(self, message: str, details: dict | None = None) -> None: - """Handle when a warning occurs.""" - self._workflow_callbacks.on_warning(message, details) - - def log(self, message: str, details: dict | None = None) -> None: - """Handle when a log occurs.""" - self._workflow_callbacks.on_log(message, details) - - def measure(self, name: str, value: float, details: dict | None = None) -> None: - """Handle when a measurement occurs.""" - self._workflow_callbacks.on_measure(name, value, details) diff --git a/graphrag/callbacks/file_workflow_callbacks.py b/graphrag/callbacks/file_workflow_callbacks.py index b3b5ca1963..1f476f8c52 100644 --- a/graphrag/callbacks/file_workflow_callbacks.py +++ b/graphrag/callbacks/file_workflow_callbacks.py @@ -25,7 +25,7 @@ def __init__(self, directory: str): Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict" ) - def on_error( + def error( self, message: str, cause: BaseException | None = None, @@ -50,7 +50,7 @@ def on_error( message = f"{message} details={details}" log.info(message) - def on_warning(self, message: str, details: dict | None = None): + def warning(self, message: str, details: dict | None = None): """Handle when a warning occurs.""" self._out_stream.write( json.dumps( @@ -61,7 +61,7 @@ def on_warning(self, message: str, details: dict | None = None): ) _print_warning(message) - def on_log(self, message: str, details: dict | None = None): + def log(self, message: str, details: dict | None = None): """Handle when a log message is produced.""" self._out_stream.write( json.dumps( diff --git a/graphrag/callbacks/noop_verb_callbacks.py b/graphrag/callbacks/noop_verb_callbacks.py deleted file mode 100644 index 5a2000af67..0000000000 --- a/graphrag/callbacks/noop_verb_callbacks.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Defines the interface for verb callbacks.""" - -from graphrag.callbacks.verb_callbacks import VerbCallbacks -from graphrag.logger.progress import Progress - - -class NoopVerbCallbacks(VerbCallbacks): - """A noop implementation of the verb callbacks.""" - - def __init__(self) -> None: - pass - - def progress(self, progress: Progress) -> None: - """Report a progress update from the verb execution".""" - - def error( - self, - message: str, - cause: BaseException | None = None, - stack: str | None = None, - details: dict | None = None, - ) -> None: - """Report a error from the verb execution.""" - - def warning(self, message: str, details: dict | None = None) -> None: - """Report a warning from verb execution.""" - - def log(self, message: str, details: dict | None = None) -> None: - """Report an informational message from the verb execution.""" - - def measure(self, name: str, value: float) -> None: - """Report a telemetry measurement from the verb execution.""" diff --git a/graphrag/callbacks/noop_workflow_callbacks.py b/graphrag/callbacks/noop_workflow_callbacks.py index 2e8d6b883d..4678338f86 100644 --- a/graphrag/callbacks/noop_workflow_callbacks.py +++ b/graphrag/callbacks/noop_workflow_callbacks.py @@ -3,8 +3,6 @@ """A no-op implementation of WorkflowCallbacks.""" -from typing import Any - from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.logger.progress import Progress @@ -12,22 +10,16 @@ class NoopWorkflowCallbacks(WorkflowCallbacks): """A no-op implementation of WorkflowCallbacks.""" - def on_workflow_start(self, name: str, instance: object) -> None: + def workflow_start(self, name: str, instance: object) -> None: """Execute this callback when a workflow starts.""" - def on_workflow_end(self, name: str, instance: object) -> None: + def workflow_end(self, name: str, instance: object) -> None: """Execute this callback when a workflow ends.""" - def on_step_start(self, step_name: str) -> None: - """Execute this callback every time a step starts.""" - - def on_step_end(self, step_name: str, result: Any) -> None: - """Execute this callback every time a step ends.""" - - def on_step_progress(self, step_name: str, progress: Progress) -> None: + def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" - def on_error( + def error( self, message: str, cause: BaseException | None = None, @@ -36,11 +28,8 @@ def on_error( ) -> None: """Handle when an error occurs.""" - def on_warning(self, message: str, details: dict | None = None) -> None: + def warning(self, message: str, details: dict | None = None) -> None: """Handle when a warning occurs.""" - def on_log(self, message: str, details: dict | None = None) -> None: + def log(self, message: str, details: dict | None = None) -> None: """Handle when a log message occurs.""" - - def on_measure(self, name: str, value: float, details: dict | None = None) -> None: - """Handle when a measurement occurs.""" diff --git a/graphrag/callbacks/progress_workflow_callbacks.py b/graphrag/callbacks/progress_workflow_callbacks.py index 1dc4ada022..cff483187f 100644 --- a/graphrag/callbacks/progress_workflow_callbacks.py +++ b/graphrag/callbacks/progress_workflow_callbacks.py @@ -3,8 +3,6 @@ """A workflow callback manager that emits updates.""" -from typing import Any - from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.logger.base import ProgressLogger from graphrag.logger.progress import Progress @@ -31,23 +29,14 @@ def _push(self, name: str) -> None: def _latest(self) -> ProgressLogger: return self._progress_stack[-1] - def on_workflow_start(self, name: str, instance: object) -> None: + def workflow_start(self, name: str, instance: object) -> None: """Execute this callback when a workflow starts.""" self._push(name) - def on_workflow_end(self, name: str, instance: object) -> None: + def workflow_end(self, name: str, instance: object) -> None: """Execute this callback when a workflow ends.""" self._pop() - def on_step_start(self, step_name: str) -> None: - """Execute this callback every time a step starts.""" - self._push(f"Step {step_name}") - self._latest(Progress(percent=0)) - - def on_step_end(self, step_name: str, result: Any) -> None: - """Execute this callback every time a step ends.""" - self._pop() - - def on_step_progress(self, step_name: str, progress: Progress) -> None: + def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" self._latest(progress) diff --git a/graphrag/callbacks/verb_callbacks.py b/graphrag/callbacks/verb_callbacks.py deleted file mode 100644 index 9489b4cab3..0000000000 --- a/graphrag/callbacks/verb_callbacks.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Defines the interface for verb callbacks.""" - -from typing import Protocol - -from graphrag.logger.progress import Progress - - -class VerbCallbacks(Protocol): - """Provides a way to report status updates from the pipeline.""" - - def progress(self, progress: Progress) -> None: - """Report a progress update from the verb execution".""" - ... - - def error( - self, - message: str, - cause: BaseException | None = None, - stack: str | None = None, - details: dict | None = None, - ) -> None: - """Report a error from the verb execution.""" - ... - - def warning(self, message: str, details: dict | None = None) -> None: - """Report a warning from verb execution.""" - ... - - def log(self, message: str, details: dict | None = None) -> None: - """Report an informational message from the verb execution.""" - ... - - def measure(self, name: str, value: float) -> None: - """Report a telemetry measurement from the verb execution.""" - ... diff --git a/graphrag/callbacks/workflow_callbacks.py b/graphrag/callbacks/workflow_callbacks.py index f1adec6cb6..c01f389ceb 100644 --- a/graphrag/callbacks/workflow_callbacks.py +++ b/graphrag/callbacks/workflow_callbacks.py @@ -3,7 +3,7 @@ """Collection of callbacks that can be used to monitor the workflow execution.""" -from typing import Any, Protocol +from typing import Protocol from graphrag.logger.progress import Progress @@ -15,27 +15,19 @@ class WorkflowCallbacks(Protocol): This base class is a "noop" implementation so that clients may implement just the callbacks they need. """ - def on_workflow_start(self, name: str, instance: object) -> None: + def workflow_start(self, name: str, instance: object) -> None: """Execute this callback when a workflow starts.""" ... - def on_workflow_end(self, name: str, instance: object) -> None: + def workflow_end(self, name: str, instance: object) -> None: """Execute this callback when a workflow ends.""" ... - def on_step_start(self, step_name: str) -> None: - """Execute this callback every time a step starts.""" - ... - - def on_step_end(self, step_name: str, result: Any) -> None: - """Execute this callback every time a step ends.""" - ... - - def on_step_progress(self, step_name: str, progress: Progress) -> None: + def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" ... - def on_error( + def error( self, message: str, cause: BaseException | None = None, @@ -45,14 +37,10 @@ def on_error( """Handle when an error occurs.""" ... - def on_warning(self, message: str, details: dict | None = None) -> None: + def warning(self, message: str, details: dict | None = None) -> None: """Handle when a warning occurs.""" ... - def on_log(self, message: str, details: dict | None = None) -> None: + def log(self, message: str, details: dict | None = None) -> None: """Handle when a log message occurs.""" ... - - def on_measure(self, name: str, value: float, details: dict | None = None) -> None: - """Handle when a measurement occurs.""" - ... diff --git a/graphrag/callbacks/workflow_callbacks_manager.py b/graphrag/callbacks/workflow_callbacks_manager.py index d677462cb7..4c02a0f2fb 100644 --- a/graphrag/callbacks/workflow_callbacks_manager.py +++ b/graphrag/callbacks/workflow_callbacks_manager.py @@ -3,8 +3,6 @@ """A module containing the WorkflowCallbacks registry.""" -from typing import Any - from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.logger.progress import Progress @@ -22,37 +20,25 @@ def register(self, callbacks: WorkflowCallbacks) -> None: """Register a new WorkflowCallbacks type.""" self._callbacks.append(callbacks) - def on_workflow_start(self, name: str, instance: object) -> None: + def workflow_start(self, name: str, instance: object) -> None: """Execute this callback when a workflow starts.""" for callback in self._callbacks: - if hasattr(callback, "on_workflow_start"): - callback.on_workflow_start(name, instance) + if hasattr(callback, "workflow_start"): + callback.workflow_start(name, instance) - def on_workflow_end(self, name: str, instance: object) -> None: + def workflow_end(self, name: str, instance: object) -> None: """Execute this callback when a workflow ends.""" for callback in self._callbacks: - if hasattr(callback, "on_workflow_end"): - callback.on_workflow_end(name, instance) - - def on_step_start(self, step_name: str) -> None: - """Execute this callback every time a step starts.""" - for callback in self._callbacks: - if hasattr(callback, "on_step_start"): - callback.on_step_start(step_name) + if hasattr(callback, "workflow_end"): + callback.workflow_end(name, instance) - def on_step_end(self, step_name: str, result: Any) -> None: - """Execute this callback every time a step ends.""" - for callback in self._callbacks: - if hasattr(callback, "on_step_end"): - callback.on_step_end(step_name, result) - - def on_step_progress(self, step_name: str, progress: Progress) -> None: + def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" for callback in self._callbacks: - if hasattr(callback, "on_step_progress"): - callback.on_step_progress(step_name, progress) + if hasattr(callback, "progress"): + callback.progress(progress) - def on_error( + def error( self, message: str, cause: BaseException | None = None, @@ -61,23 +47,17 @@ def on_error( ) -> None: """Handle when an error occurs.""" for callback in self._callbacks: - if hasattr(callback, "on_error"): - callback.on_error(message, cause, stack, details) + if hasattr(callback, "error"): + callback.error(message, cause, stack, details) - def on_warning(self, message: str, details: dict | None = None) -> None: + def warning(self, message: str, details: dict | None = None) -> None: """Handle when a warning occurs.""" for callback in self._callbacks: - if hasattr(callback, "on_warning"): - callback.on_warning(message, details) + if hasattr(callback, "warning"): + callback.warning(message, details) - def on_log(self, message: str, details: dict | None = None) -> None: + def log(self, message: str, details: dict | None = None) -> None: """Handle when a log message occurs.""" for callback in self._callbacks: - if hasattr(callback, "on_log"): - callback.on_log(message, details) - - def on_measure(self, name: str, value: float, details: dict | None = None) -> None: - """Handle when a measurement occurs.""" - for callback in self._callbacks: - if hasattr(callback, "on_measure"): - callback.on_measure(name, value, details) + if hasattr(callback, "log"): + callback.log(message, details) diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py index 33dad0aebd..f0ca4d60d5 100644 --- a/graphrag/index/flows/create_base_text_units.py +++ b/graphrag/index/flows/create_base_text_units.py @@ -7,7 +7,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.chunking_config import ChunkStrategyType from graphrag.index.operations.chunk_text.chunk_text import chunk_text from graphrag.index.utils.hashing import gen_sha512_hash @@ -16,7 +16,7 @@ def create_base_text_units( documents: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, group_by_columns: list[str], size: int, overlap: int, diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py index f94103db04..48cad7a93d 100644 --- a/graphrag/index/flows/create_final_community_reports.py +++ b/graphrag/index/flows/create_final_community_reports.py @@ -8,7 +8,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.operations.summarize_communities import ( prepare_community_reports, @@ -43,7 +43,7 @@ async def create_final_community_reports( entities: pd.DataFrame, communities: pd.DataFrame, claims_input: pd.DataFrame | None, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, summarization_strategy: dict, async_mode: AsyncType = AsyncType.AsyncIO, diff --git a/graphrag/index/flows/create_final_covariates.py b/graphrag/index/flows/create_final_covariates.py index ce6cccaa9c..eff2e65b55 100644 --- a/graphrag/index/flows/create_final_covariates.py +++ b/graphrag/index/flows/create_final_covariates.py @@ -9,7 +9,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.operations.extract_covariates.extract_covariates import ( extract_covariates, @@ -18,7 +18,7 @@ async def create_final_covariates( text_units: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, covariate_type: str, extraction_strategy: dict[str, Any] | None, diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py index f75ef2733a..a0bba2dad7 100644 --- a/graphrag/index/flows/create_final_nodes.py +++ b/graphrag/index/flows/create_final_nodes.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.embed_graph_config import EmbedGraphConfig from graphrag.index.operations.compute_degree import compute_degree from graphrag.index.operations.create_graph import create_graph @@ -17,7 +17,7 @@ def create_final_nodes( base_entity_nodes: pd.DataFrame, base_relationship_edges: pd.DataFrame, base_communities: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, embed_config: EmbedGraphConfig, layout_enabled: bool, ) -> pd.DataFrame: diff --git a/graphrag/index/flows/extract_graph.py b/graphrag/index/flows/extract_graph.py index 8eaa4d2951..4fc513f918 100644 --- a/graphrag/index/flows/extract_graph.py +++ b/graphrag/index/flows/extract_graph.py @@ -9,7 +9,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.operations.extract_entities import extract_entities from graphrag.index.operations.summarize_descriptions import ( @@ -19,7 +19,7 @@ async def extract_graph( text_units: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, extraction_strategy: dict[str, Any] | None = None, extraction_num_threads: int = 4, diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py index d8c547663d..25831f8e8a 100644 --- a/graphrag/index/flows/generate_text_embeddings.py +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -8,7 +8,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.config.embeddings import ( community_full_content_embedding, community_summary_embedding, @@ -32,7 +32,7 @@ async def generate_text_embeddings( final_text_units: pd.DataFrame | None, final_entities: pd.DataFrame | None, final_community_reports: pd.DataFrame | None, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, storage: PipelineStorage, text_embed_config: dict, @@ -110,7 +110,7 @@ async def _run_and_snapshot_embeddings( name: str, data: pd.DataFrame, embed_column: str, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, storage: PipelineStorage, text_embed_config: dict, diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index eae2cf34bd..a4f9419605 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from graphrag.cache.pipeline_cache import PipelineCache - from graphrag.callbacks.verb_callbacks import VerbCallbacks + from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.typing import ErrorHandlerFn log = logging.getLogger(__name__) @@ -105,7 +105,7 @@ def load_llm( name: str, config: LLMParameters, *, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache | None, chat_only=False, ) -> ChatLLM: @@ -135,7 +135,7 @@ def load_llm_embeddings( name: str, llm_config: LLMParameters, *, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache | None, chat_only=False, ) -> EmbeddingsLLM: @@ -160,7 +160,7 @@ def load_llm_embeddings( raise ValueError(msg) -def _create_error_handler(callbacks: VerbCallbacks) -> ErrorHandlerFn: +def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn: def on_error( error: BaseException | None = None, stack: str | None = None, diff --git a/graphrag/index/operations/chunk_text/chunk_text.py b/graphrag/index/operations/chunk_text/chunk_text.py index 02c12e6f1a..a673c2e832 100644 --- a/graphrag/index/operations/chunk_text/chunk_text.py +++ b/graphrag/index/operations/chunk_text/chunk_text.py @@ -7,7 +7,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType from graphrag.index.operations.chunk_text.typing import ( ChunkInput, @@ -23,7 +23,7 @@ def chunk_text( overlap: int, encoding_model: str, strategy: ChunkStrategyType, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.Series: """ Chunk a piece of text into smaller pieces. diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index f4a7e5f367..793625d906 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -11,7 +11,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy from graphrag.utils.embeddings import create_collection_name from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument @@ -37,7 +37,7 @@ def __repr__(self): async def embed_text( input: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, embed_column: str, strategy: dict, @@ -109,7 +109,7 @@ async def embed_text( async def _text_embed_in_memory( input: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, embed_column: str, strategy: dict, @@ -126,7 +126,7 @@ async def _text_embed_in_memory( async def _text_embed_with_vector_store( input: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, embed_column: str, strategy: dict[str, Any], diff --git a/graphrag/index/operations/embed_text/strategies/mock.py b/graphrag/index/operations/embed_text/strategies/mock.py index 9facd66643..6aa60ff3ef 100644 --- a/graphrag/index/operations/embed_text/strategies/mock.py +++ b/graphrag/index/operations/embed_text/strategies/mock.py @@ -8,14 +8,14 @@ from typing import Any from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult from graphrag.logger.progress import ProgressTicker, progress_ticker async def run( # noqa RUF029 async is required for interface input: list[str], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, _args: dict[str, Any], ) -> TextEmbeddingResult: diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/strategies/openai.py index 5bef604dab..4bfbcf0ffa 100644 --- a/graphrag/index/operations/embed_text/strategies/openai.py +++ b/graphrag/index/operations/embed_text/strategies/openai.py @@ -13,7 +13,7 @@ import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.llm_parameters import LLMParameters from graphrag.index.llm.load_llm import load_llm_embeddings from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult @@ -26,7 +26,7 @@ async def run( input: list[str], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, args: dict[str, Any], ) -> TextEmbeddingResult: @@ -75,7 +75,7 @@ def _get_splitter(config: LLMParameters, batch_max_tokens: int) -> TokenTextSpli def _get_llm( config: LLMParameters, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, ) -> EmbeddingsLLM: return load_llm_embeddings( diff --git a/graphrag/index/operations/embed_text/strategies/typing.py b/graphrag/index/operations/embed_text/strategies/typing.py index 5962045a67..f45a7eb36e 100644 --- a/graphrag/index/operations/embed_text/strategies/typing.py +++ b/graphrag/index/operations/embed_text/strategies/typing.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @dataclass @@ -20,7 +20,7 @@ class TextEmbeddingResult: TextEmbeddingStrategy = Callable[ [ list[str], - VerbCallbacks, + WorkflowCallbacks, PipelineCache, dict, ], diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py index 323d95627d..70f0585b98 100644 --- a/graphrag/index/operations/extract_covariates/extract_covariates.py +++ b/graphrag/index/operations/extract_covariates/extract_covariates.py @@ -12,7 +12,7 @@ import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor @@ -30,7 +30,7 @@ async def extract_covariates( input: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, column: str, covariate_type: str, @@ -78,7 +78,7 @@ async def run_claim_extraction( input: str | Iterable[str], entity_types: list[str], resolved_entities_map: dict[str, str], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, strategy_config: dict[str, Any], ) -> CovariateExtractionResult: diff --git a/graphrag/index/operations/extract_covariates/typing.py b/graphrag/index/operations/extract_covariates/typing.py index 8f95b9b5fb..a524b2bc17 100644 --- a/graphrag/index/operations/extract_covariates/typing.py +++ b/graphrag/index/operations/extract_covariates/typing.py @@ -8,7 +8,7 @@ from typing import Any from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @dataclass @@ -41,7 +41,7 @@ class CovariateExtractionResult: Iterable[str], list[str], dict[str, str], - VerbCallbacks, + WorkflowCallbacks, PipelineCache, dict[str, Any], ], diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py index d50e1219b3..243ac200f6 100644 --- a/graphrag/index/operations/extract_entities/extract_entities.py +++ b/graphrag/index/operations/extract_entities/extract_entities.py @@ -9,7 +9,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.bootstrap import bootstrap from graphrag.index.operations.extract_entities.typing import ( @@ -27,7 +27,7 @@ async def extract_entities( text_units: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, text_column: str, id_column: str, diff --git a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py index 2a403112a1..919858779a 100644 --- a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py +++ b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py @@ -8,7 +8,7 @@ import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor from graphrag.index.operations.extract_entities.typing import ( @@ -22,7 +22,7 @@ async def run_graph_intelligence( docs: list[Document], entity_types: EntityTypes, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, args: StrategyConfig, ) -> EntityExtractionResult: @@ -36,7 +36,7 @@ async def run_extract_entities( llm: ChatLLM, docs: list[Document], entity_types: EntityTypes, - callbacks: VerbCallbacks | None, + callbacks: WorkflowCallbacks | None, args: StrategyConfig, ) -> EntityExtractionResult: """Run the entity extraction chain.""" diff --git a/graphrag/index/operations/extract_entities/nltk_strategy.py b/graphrag/index/operations/extract_entities/nltk_strategy.py index e133aeeab4..d9810b8a38 100644 --- a/graphrag/index/operations/extract_entities/nltk_strategy.py +++ b/graphrag/index/operations/extract_entities/nltk_strategy.py @@ -8,7 +8,7 @@ from nltk.corpus import words from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.extract_entities.typing import ( Document, EntityExtractionResult, @@ -23,7 +23,7 @@ async def run( # noqa RUF029 async is required for interface docs: list[Document], entity_types: EntityTypes, - callbacks: VerbCallbacks, # noqa ARG001 + callbacks: WorkflowCallbacks, # noqa ARG001 cache: PipelineCache, # noqa ARG001 args: StrategyConfig, # noqa ARG001 ) -> EntityExtractionResult: diff --git a/graphrag/index/operations/extract_entities/typing.py b/graphrag/index/operations/extract_entities/typing.py index 247c781003..0361317f76 100644 --- a/graphrag/index/operations/extract_entities/typing.py +++ b/graphrag/index/operations/extract_entities/typing.py @@ -11,7 +11,7 @@ import networkx as nx from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks ExtractedEntity = dict[str, Any] ExtractedRelationship = dict[str, Any] @@ -40,7 +40,7 @@ class EntityExtractionResult: [ list[Document], EntityTypes, - VerbCallbacks, + WorkflowCallbacks, PipelineCache, StrategyConfig, ], diff --git a/graphrag/index/operations/layout_graph/layout_graph.py b/graphrag/index/operations/layout_graph/layout_graph.py index b96ef91e34..f004f54fe0 100644 --- a/graphrag/index/operations/layout_graph/layout_graph.py +++ b/graphrag/index/operations/layout_graph/layout_graph.py @@ -6,14 +6,14 @@ import networkx as nx import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.embed_graph.typing import NodeEmbeddings from graphrag.index.operations.layout_graph.typing import GraphLayout def layout_graph( graph: nx.Graph, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, enabled: bool, embeddings: NodeEmbeddings | None, ): @@ -58,7 +58,7 @@ def _run_layout( graph: nx.Graph, enabled: bool, embeddings: NodeEmbeddings, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> GraphLayout: if enabled: from graphrag.index.operations.layout_graph.umap import ( diff --git a/graphrag/index/operations/summarize_communities/prepare_community_reports.py b/graphrag/index/operations/summarize_communities/prepare_community_reports.py index 66fcaa2bb5..d1f88f6d94 100644 --- a/graphrag/index/operations/summarize_communities/prepare_community_reports.py +++ b/graphrag/index/operations/summarize_communities/prepare_community_reports.py @@ -8,7 +8,7 @@ import pandas as pd import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( parallel_sort_context_batch, ) @@ -24,7 +24,7 @@ def prepare_community_reports( nodes, edges, claims, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, max_tokens: int = 16_000, ): """Prep communities for report generation.""" diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py index e630baba73..2d0c6af521 100644 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ b/graphrag/index/operations/summarize_communities/strategies.py @@ -9,7 +9,7 @@ from fnllm import ChatLLM from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import ( CommunityReportsExtractor, @@ -28,7 +28,7 @@ async def run_graph_intelligence( community: str | int, input: str, level: int, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, args: StrategyConfig, ) -> CommunityReport | None: @@ -44,7 +44,7 @@ async def _run_extractor( input: str, level: int, args: StrategyConfig, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> CommunityReport | None: # RateLimiter rate_limiter = RateLimiter(rate=1, per=60) diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index df6dd631e1..4f29b062ab 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py @@ -10,8 +10,8 @@ import graphrag.config.defaults as defaults import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.operations.summarize_communities.community_reports_extractor import ( prep_community_report_context, @@ -34,7 +34,7 @@ async def summarize_communities( local_contexts, nodes, community_hierarchy, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, strategy: dict, async_mode: AsyncType = AsyncType.AsyncIO, @@ -73,7 +73,7 @@ async def run_generate(record): local_reports = await derive_from_rows( level_contexts, run_generate, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), num_threads=num_threads, async_type=async_mode, ) @@ -84,7 +84,7 @@ async def run_generate(record): async def _generate_report( runner: CommunityReportsStrategy, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, strategy: dict, community_id: int, diff --git a/graphrag/index/operations/summarize_communities/typing.py b/graphrag/index/operations/summarize_communities/typing.py index 2a1ed3aca5..b48a05a6e8 100644 --- a/graphrag/index/operations/summarize_communities/typing.py +++ b/graphrag/index/operations/summarize_communities/typing.py @@ -10,7 +10,7 @@ from typing_extensions import TypedDict from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks ExtractedEntity = dict[str, Any] StrategyConfig = dict[str, Any] @@ -45,7 +45,7 @@ class CommunityReport(TypedDict): str | int, str, int, - VerbCallbacks, + WorkflowCallbacks, PipelineCache, StrategyConfig, ], diff --git a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py index e5de39f57f..1c5cccb022 100644 --- a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py +++ b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py @@ -6,7 +6,7 @@ from fnllm import ChatLLM from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.summarize_descriptions.description_summary_extractor import ( SummarizeExtractor, @@ -20,7 +20,7 @@ async def run_graph_intelligence( id: str | tuple[str, str], descriptions: list[str], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, args: StrategyConfig, ) -> SummarizedDescriptionResult: @@ -36,7 +36,7 @@ async def run_summarize_descriptions( llm: ChatLLM, id: str | tuple[str, str], descriptions: list[str], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, args: StrategyConfig, ) -> SummarizedDescriptionResult: """Run the entity extraction chain.""" diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index d1ad4af487..85674873d6 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -10,7 +10,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.summarize_descriptions.typing import ( SummarizationStrategy, SummarizeStrategyType, @@ -23,7 +23,7 @@ async def summarize_descriptions( entities_df: pd.DataFrame, relationships_df: pd.DataFrame, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, cache: PipelineCache, strategy: dict[str, Any] | None = None, num_threads: int = 4, diff --git a/graphrag/index/operations/summarize_descriptions/typing.py b/graphrag/index/operations/summarize_descriptions/typing.py index 919ff9fd1c..565fc8eab2 100644 --- a/graphrag/index/operations/summarize_descriptions/typing.py +++ b/graphrag/index/operations/summarize_descriptions/typing.py @@ -9,7 +9,7 @@ from typing import Any, NamedTuple from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks StrategyConfig = dict[str, Any] @@ -26,7 +26,7 @@ class SummarizedDescriptionResult: [ str | tuple[str, str], list[str], - VerbCallbacks, + WorkflowCallbacks, PipelineCache, StrategyConfig, ], diff --git a/graphrag/index/run/derive_from_rows.py b/graphrag/index/run/derive_from_rows.py index 283621bb93..1c04e964da 100644 --- a/graphrag/index/run/derive_from_rows.py +++ b/graphrag/index/run/derive_from_rows.py @@ -12,7 +12,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.logger.progress import progress_ticker @@ -32,7 +32,7 @@ def __init__(self, num_errors: int): async def derive_from_rows( input: pd.DataFrame, transform: Callable[[pd.Series], Awaitable[ItemType]], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, num_threads: int = 4, async_type: AsyncType = AsyncType.AsyncIO, ) -> list[ItemType | None]: @@ -57,7 +57,7 @@ async def derive_from_rows( async def derive_from_rows_asyncio_threads( input: pd.DataFrame, transform: Callable[[pd.Series], Awaitable[ItemType]], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, num_threads: int | None = 4, ) -> list[ItemType | None]: """ @@ -87,7 +87,7 @@ async def execute_task(task: Coroutine) -> ItemType | None: async def derive_from_rows_asyncio( input: pd.DataFrame, transform: Callable[[pd.Series], Awaitable[ItemType]], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, num_threads: int = 4, ) -> list[ItemType | None]: """ @@ -121,7 +121,7 @@ async def execute_row_protected( async def _derive_from_rows_base( input: pd.DataFrame, transform: Callable[[pd.Series], Awaitable[ItemType]], - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, gather: GatherFn[ItemType], ) -> list[ItemType | None]: """ diff --git a/graphrag/index/run/run_workflows.py b/graphrag/index/run/run_workflows.py index 096fe9fb1a..0a7afb456b 100644 --- a/graphrag/index/run/run_workflows.py +++ b/graphrag/index/run/run_workflows.py @@ -9,15 +9,13 @@ import traceback from collections.abc import AsyncIterable from dataclasses import asdict -from typing import cast import pandas as pd from graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks -from graphrag.callbacks.delegating_verb_callbacks import DelegatingVerbCallbacks -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunStats @@ -119,7 +117,7 @@ async def run_workflows( update_storage=update_index_storage, config=config, cache=cache, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), progress_logger=progress_logger, ) @@ -163,16 +161,15 @@ async def _run_workflows( last_workflow = workflow run_workflow = all_workflows[workflow] progress = logger.child(workflow, transient=False) - callbacks.on_workflow_start(workflow, None) - verb_callbacks = DelegatingVerbCallbacks(workflow, callbacks) + callbacks.workflow_start(workflow, None) work_time = time.time() result = await run_workflow( config, context, - verb_callbacks, + callbacks, ) progress(Progress(percent=1)) - callbacks.on_workflow_end(workflow, result) + callbacks.workflow_end(workflow, result) yield PipelineRunResult(workflow, result, None) context.stats.workflows[workflow] = {"overall": time.time() - work_time} @@ -186,9 +183,7 @@ async def _run_workflows( except Exception as e: log.exception("error running workflow %s", last_workflow) - cast("WorkflowCallbacks", callbacks).on_error( - "Error running pipeline!", e, traceback.format_exc() - ) + callbacks.error("Error running pipeline!", e, traceback.format_exc()) yield PipelineRunResult(last_workflow, None, [e]) diff --git a/graphrag/index/update/entities.py b/graphrag/index/update/entities.py index 849fa4a749..e94e0f95aa 100644 --- a/graphrag/index/update/entities.py +++ b/graphrag/index/update/entities.py @@ -10,7 +10,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import ( run_graph_intelligence as run_entity_summarization, @@ -92,7 +92,7 @@ async def _run_entity_summarization( entities_df: pd.DataFrame, config: GraphRagConfig, cache: PipelineCache, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame: """Run entity summarization. diff --git a/graphrag/index/update/incremental_index.py b/graphrag/index/update/incremental_index.py index 4ba486af6b..0c6a44373c 100644 --- a/graphrag/index/update/incremental_index.py +++ b/graphrag/index/update/incremental_index.py @@ -9,7 +9,7 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings @@ -86,7 +86,7 @@ async def update_dataframe_outputs( update_storage: PipelineStorage, config: GraphRagConfig, cache: PipelineCache, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, progress_logger: ProgressLogger, ) -> None: """Update the mergeable outputs. diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index a98e4cb707..73e023afb6 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -6,7 +6,7 @@ import asyncio import sys -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings from graphrag.logger.print_progress import ProgressLogger @@ -18,7 +18,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> llm = load_llm( "test-llm", parameters.llm, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), cache=None, ) try: @@ -32,7 +32,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> embed_llm = load_llm_embeddings( "test-embed-llm", parameters.embeddings.llm, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), cache=None, ) try: diff --git a/graphrag/index/workflows/__init__.py b/graphrag/index/workflows/__init__.py index a904dc7bb8..a82a305634 100644 --- a/graphrag/index/workflows/__init__.py +++ b/graphrag/index/workflows/__init__.py @@ -8,7 +8,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext @@ -88,7 +88,7 @@ all_workflows: dict[ str, Callable[ - [GraphRagConfig, PipelineRunContext, VerbCallbacks], + [GraphRagConfig, PipelineRunContext, WorkflowCallbacks], Awaitable[pd.DataFrame | None], ], ] = { diff --git a/graphrag/index/workflows/compute_communities.py b/graphrag/index/workflows/compute_communities.py index 51cf511d50..a5bc9e0eba 100644 --- a/graphrag/index/workflows/compute_communities.py +++ b/graphrag/index/workflows/compute_communities.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.compute_communities import compute_communities @@ -17,7 +17,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - _callbacks: VerbCallbacks, + _callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to create the base communities.""" base_relationship_edges = await load_table_from_storage( diff --git a/graphrag/index/workflows/create_base_text_units.py b/graphrag/index/workflows/create_base_text_units.py index 91d5822884..e882eca149 100644 --- a/graphrag/index/workflows/create_base_text_units.py +++ b/graphrag/index/workflows/create_base_text_units.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_base_text_units import ( @@ -19,7 +19,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform base text_units.""" documents = await load_table_from_storage("input", context.storage) diff --git a/graphrag/index/workflows/create_final_communities.py b/graphrag/index/workflows/create_final_communities.py index e1cf950e97..fa08224d5a 100644 --- a/graphrag/index/workflows/create_final_communities.py +++ b/graphrag/index/workflows/create_final_communities.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_communities import ( @@ -19,7 +19,7 @@ async def run_workflow( _config: GraphRagConfig, context: PipelineRunContext, - _callbacks: VerbCallbacks, + _callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform final communities.""" base_entity_nodes = await load_table_from_storage( diff --git a/graphrag/index/workflows/create_final_community_reports.py b/graphrag/index/workflows/create_final_community_reports.py index 7aacc79fbf..8d1b613036 100644 --- a/graphrag/index/workflows/create_final_community_reports.py +++ b/graphrag/index/workflows/create_final_community_reports.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_community_reports import ( @@ -19,7 +19,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform community reports.""" nodes = await load_table_from_storage("create_final_nodes", context.storage) diff --git a/graphrag/index/workflows/create_final_covariates.py b/graphrag/index/workflows/create_final_covariates.py index 9ab91fdf16..7830889c96 100644 --- a/graphrag/index/workflows/create_final_covariates.py +++ b/graphrag/index/workflows/create_final_covariates.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_covariates import ( @@ -19,7 +19,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to extract and format covariates.""" text_units = await load_table_from_storage( diff --git a/graphrag/index/workflows/create_final_documents.py b/graphrag/index/workflows/create_final_documents.py index bbc1490b8f..d4b643a852 100644 --- a/graphrag/index/workflows/create_final_documents.py +++ b/graphrag/index/workflows/create_final_documents.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_documents import ( @@ -19,7 +19,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - _callbacks: VerbCallbacks, + _callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform final documents.""" documents = await load_table_from_storage("input", context.storage) diff --git a/graphrag/index/workflows/create_final_entities.py b/graphrag/index/workflows/create_final_entities.py index 565da6cf6b..33213062cc 100644 --- a/graphrag/index/workflows/create_final_entities.py +++ b/graphrag/index/workflows/create_final_entities.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_entities import ( @@ -19,7 +19,7 @@ async def run_workflow( _config: GraphRagConfig, context: PipelineRunContext, - _callbacks: VerbCallbacks, + _callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform final entities.""" base_entity_nodes = await load_table_from_storage( diff --git a/graphrag/index/workflows/create_final_nodes.py b/graphrag/index/workflows/create_final_nodes.py index aa1ec3c177..6c0109a897 100644 --- a/graphrag/index/workflows/create_final_nodes.py +++ b/graphrag/index/workflows/create_final_nodes.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_nodes import ( @@ -19,7 +19,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform final nodes.""" base_entity_nodes = await load_table_from_storage( diff --git a/graphrag/index/workflows/create_final_relationships.py b/graphrag/index/workflows/create_final_relationships.py index f6896420b0..fb672d79eb 100644 --- a/graphrag/index/workflows/create_final_relationships.py +++ b/graphrag/index/workflows/create_final_relationships.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_relationships import ( @@ -19,7 +19,7 @@ async def run_workflow( _config: GraphRagConfig, context: PipelineRunContext, - _callbacks: VerbCallbacks, + _callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform final relationships.""" base_relationship_edges = await load_table_from_storage( diff --git a/graphrag/index/workflows/create_final_text_units.py b/graphrag/index/workflows/create_final_text_units.py index d9d49fec4f..53e1db5593 100644 --- a/graphrag/index/workflows/create_final_text_units.py +++ b/graphrag/index/workflows/create_final_text_units.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.create_final_text_units import ( @@ -19,7 +19,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - _callbacks: VerbCallbacks, + _callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform the text units.""" text_units = await load_table_from_storage( diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index 454bf7806a..94227f1fad 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.flows.extract_graph import ( @@ -21,7 +21,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to create the base entity graph.""" text_units = await load_table_from_storage( diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index 29a8bf0988..df1ee0ada8 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings from graphrag.index.context import PipelineRunContext @@ -20,7 +20,7 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, - callbacks: VerbCallbacks, + callbacks: WorkflowCallbacks, ) -> pd.DataFrame | None: """All the steps to transform community reports.""" final_documents = await load_table_from_storage( diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index db8a95804a..fa069f50b0 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -9,7 +9,7 @@ from pydantic import TypeAdapter import graphrag.config.defaults as defs -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.config.models.llm_parameters import LLMParameters from graphrag.index.input.factory import create_input @@ -77,7 +77,7 @@ async def load_docs_in_chunks( overlap=MIN_CHUNK_OVERLAP, encoding_model=defs.ENCODING_MODEL, strategy=chunk_config.strategy, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), ) # Select chunks into a new df and explode it @@ -98,7 +98,7 @@ async def load_docs_in_chunks( embedding_llm = load_llm_embeddings( "prompt_tuning_embeddings", llm_config, - callbacks=NoopVerbCallbacks(), + callbacks=NoopWorkflowCallbacks(), cache=None, ) diff --git a/tests/verbs/test_compute_communities.py b/tests/verbs/test_compute_communities.py index a460793e0b..8f6b5b24fc 100644 --- a/tests/verbs/test_compute_communities.py +++ b/tests/verbs/test_compute_communities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.compute_communities import run_workflow from graphrag.utils.storage import load_table_from_storage @@ -25,7 +25,7 @@ async def test_compute_communities(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage("base_communities", context.storage) diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py index 587db6549d..d8627beafb 100644 --- a/tests/verbs/test_create_base_text_units.py +++ b/tests/verbs/test_create_base_text_units.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_base_text_units import run_workflow, workflow_name from graphrag.utils.storage import load_table_from_storage @@ -25,7 +25,7 @@ async def test_create_base_text_units(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_create_final_communities.py b/tests/verbs/test_create_final_communities.py index 07c9e9baa5..23f36b37e6 100644 --- a/tests/verbs/test_create_final_communities.py +++ b/tests/verbs/test_create_final_communities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_final_communities import ( run_workflow, @@ -32,7 +32,7 @@ async def test_create_final_communities(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py index 896fe6e3cb..dddabf0910 100644 --- a/tests/verbs/test_create_final_community_reports.py +++ b/tests/verbs/test_create_final_community_reports.py @@ -4,7 +4,7 @@ import pytest -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import LLMType from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import ( @@ -70,7 +70,7 @@ async def test_create_final_community_reports(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) @@ -105,5 +105,5 @@ async def test_create_final_community_reports_missing_llm_throws(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) diff --git a/tests/verbs/test_create_final_covariates.py b/tests/verbs/test_create_final_covariates.py index 8236abd7bc..4e170edf71 100644 --- a/tests/verbs/test_create_final_covariates.py +++ b/tests/verbs/test_create_final_covariates.py @@ -4,7 +4,7 @@ import pytest from pandas.testing import assert_series_equal -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import LLMType from graphrag.index.run.derive_from_rows import ParallelizationError @@ -46,7 +46,7 @@ async def test_create_final_covariates(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) @@ -95,5 +95,5 @@ async def test_create_final_covariates_missing_llm_throws(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index a6916530a0..62db673132 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_final_documents import ( run_workflow, @@ -28,7 +28,7 @@ async def test_create_final_documents(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) @@ -49,7 +49,7 @@ async def test_create_final_documents_with_attribute_columns(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_create_final_entities.py b/tests/verbs/test_create_final_entities.py index 6d4430d398..8f1bfee5ef 100644 --- a/tests/verbs/test_create_final_entities.py +++ b/tests/verbs/test_create_final_entities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_final_entities import ( run_workflow, @@ -28,7 +28,7 @@ async def test_create_final_entities(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_create_final_nodes.py b/tests/verbs/test_create_final_nodes.py index f37cb20cec..7c0e5a8b4f 100644 --- a/tests/verbs/test_create_final_nodes.py +++ b/tests/verbs/test_create_final_nodes.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_final_nodes import ( run_workflow, @@ -32,7 +32,7 @@ async def test_create_final_nodes(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py index 223ca20ea4..4e46813d06 100644 --- a/tests/verbs/test_create_final_relationships.py +++ b/tests/verbs/test_create_final_relationships.py @@ -2,7 +2,7 @@ # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_final_relationships import ( run_workflow, @@ -29,7 +29,7 @@ async def test_create_final_relationships(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 19fb11c6f0..d6323f463c 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.index.workflows.create_final_text_units import ( run_workflow, @@ -34,7 +34,7 @@ async def test_create_final_text_units(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) @@ -60,7 +60,7 @@ async def test_create_final_text_units_no_covariates(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) actual = await load_table_from_storage(workflow_name, context.storage) diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index 68c9bb231b..b22a13f601 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -3,7 +3,7 @@ import pytest -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import LLMType from graphrag.index.workflows.extract_graph import ( @@ -68,7 +68,7 @@ async def test_extract_graph(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) # graph construction creates transient tables for nodes, edges, and communities @@ -110,5 +110,5 @@ async def test_extract_graph_missing_llm_throws(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index 640284c7ca..aabb1752fa 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import TextEmbeddingTarget from graphrag.index.config.embeddings import ( @@ -38,7 +38,7 @@ async def test_generate_text_embeddings(): await run_workflow( config, context, - NoopVerbCallbacks(), + NoopWorkflowCallbacks(), ) parquet_files = context.storage.keys()