diff --git a/.semversioner/next-release/patch-20240926032712236048.json b/.semversioner/next-release/patch-20240926032712236048.json new file mode 100644 index 0000000000..677cedd7b2 --- /dev/null +++ b/.semversioner/next-release/patch-20240926032712236048.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Reorganized api,reporter,callback code into separate components. Defined debug profiles." +} diff --git a/.vscode/launch.json b/.vscode/launch.json index 909771b809..5d8bec3194 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,12 +1,39 @@ { + "_comment": "Use this file to configure the graphrag project for debugging. You may create other configuration profiles based on these or select one below to use.", "version": "0.2.0", "configurations": [ { - "name": "Attach to Node Functions", - "type": "node", - "request": "attach", - "port": 9229, - "preLaunchTask": "func: host start" + "name": "Indexer", + "type": "debugpy", + "request": "launch", + "module": "poetry", + "args": [ + "poe", "index", + "--root", "" + ], + }, + { + "name": "Query", + "type": "debugpy", + "request": "launch", + "module": "poetry", + "args": [ + "poe", "query", + "--root", "", + "--method", "global", + "What are the top themes in this story", + ] + }, + { + "name": "Prompt Tuning", + "type": "debugpy", + "request": "launch", + "module": "poetry", + "args": [ + "poe", "prompt_tune", + "--config", + "/settings.yaml", + ] } ] -} +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index ee8ded1e70..bf20d4c881 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -38,7 +38,6 @@ ], "python.defaultInterpreterPath": "python/services/.venv/bin/python", "python.languageServer": "Pylance", - "python.analysis.typeCheckingMode": "basic", "cSpell.customDictionaries": { "project-words": { "name": "project-words", diff --git a/graphrag/api/__init__.py b/graphrag/api/__init__.py new file mode 100644 index 0000000000..120e3e41c3 --- /dev/null +++ b/graphrag/api/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""API for GraphRAG. + +WARNING: This API is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. +""" + +from .index_api import build_index +from .prompt_tune_api import DocSelectionType, generate_indexing_prompts +from .query_api import ( + global_search, + global_search_streaming, + local_search, + local_search_streaming, +) + +__all__ = [ # noqa: RUF022 + # index API + "build_index", + # query API + "global_search", + "global_search_streaming", + "local_search", + "local_search_streaming", + # prompt tuning API + "DocSelectionType", + "generate_indexing_prompts", +] diff --git a/graphrag/index/api.py b/graphrag/api/index_api.py similarity index 87% rename from graphrag/index/api.py rename to graphrag/api/index_api.py index 75206964eb..4b9a36d136 100644 --- a/graphrag/index/api.py +++ b/graphrag/api/index_api.py @@ -9,15 +9,12 @@ """ from graphrag.config import CacheType, GraphRagConfig - -from .cache.noop_pipeline_cache import NoopPipelineCache -from .create_pipeline_config import create_pipeline_config -from .emit.types import TableEmitterType -from .progress import ( - ProgressReporter, -) -from .run import run_pipeline_with_config -from .typing import PipelineRunResult +from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache +from graphrag.index.create_pipeline_config import create_pipeline_config +from graphrag.index.emit.types import TableEmitterType +from graphrag.index.run import run_pipeline_with_config +from graphrag.index.typing import PipelineRunResult +from graphrag.logging import ProgressReporter async def build_index( diff --git a/graphrag/prompt_tune/api.py b/graphrag/api/prompt_tune_api.py similarity index 96% rename from graphrag/prompt_tune/api.py rename to graphrag/api/prompt_tune_api.py index aef10f9423..dfaaa91376 100644 --- a/graphrag/prompt_tune/api.py +++ b/graphrag/api/prompt_tune_api.py @@ -16,9 +16,8 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.llm import load_llm -from graphrag.index.progress import PrintProgressReporter - -from .generator import ( +from graphrag.logging import PrintProgressReporter +from graphrag.prompt_tune.generator import ( MAX_TOKEN_COUNT, create_community_summarization_prompt, create_entity_extraction_prompt, @@ -31,11 +30,11 @@ generate_entity_types, generate_persona, ) -from .loader import ( +from graphrag.prompt_tune.loader import ( MIN_CHUNK_SIZE, load_docs_in_chunks, ) -from .types import DocSelectionType +from graphrag.prompt_tune.types import DocSelectionType @validate_call diff --git a/graphrag/query/api.py b/graphrag/api/query_api.py similarity index 98% rename from graphrag/query/api.py rename to graphrag/api/query_api.py index f496f9823b..9e18ca8893 100644 --- a/graphrag/query/api.py +++ b/graphrag/api/query_api.py @@ -25,21 +25,20 @@ from pydantic import validate_call from graphrag.config import GraphRagConfig -from graphrag.index.progress.types import PrintProgressReporter +from graphrag.logging import PrintProgressReporter from graphrag.model.entity import Entity -from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 -from graphrag.vector_stores.lancedb import LanceDBVectorStore -from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType - -from .factories import get_global_search_engine, get_local_search_engine -from .indexer_adapters import ( +from graphrag.query.factories import get_global_search_engine, get_local_search_engine +from graphrag.query.indexer_adapters import ( read_indexer_covariates, read_indexer_entities, read_indexer_relationships, read_indexer_reports, read_indexer_text_units, ) -from .input.loaders.dfs import store_entity_semantic_embeddings +from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings +from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 +from graphrag.vector_stores.lancedb import LanceDBVectorStore +from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType reporter = PrintProgressReporter("") diff --git a/graphrag/callbacks/__init__.py b/graphrag/callbacks/__init__.py new file mode 100644 index 0000000000..c6b2def2f6 --- /dev/null +++ b/graphrag/callbacks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing callback implementations.""" diff --git a/graphrag/index/reporting/blob_workflow_callbacks.py b/graphrag/callbacks/blob_workflow_callbacks.py similarity index 96% rename from graphrag/index/reporting/blob_workflow_callbacks.py rename to graphrag/callbacks/blob_workflow_callbacks.py index 59dc6e8bfe..56ed317a9f 100644 --- a/graphrag/index/reporting/blob_workflow_callbacks.py +++ b/graphrag/callbacks/blob_workflow_callbacks.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""A reporter that writes to a blob storage.""" +"""A logger that emits updates from the indexing engine to a blob in Azure Storage.""" import json from datetime import datetime, timezone @@ -14,7 +14,7 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks): - """A reporter that writes to a blob storage.""" + """A logger that writes to a blob storage account.""" _blob_service_client: BlobServiceClient _container_name: str diff --git a/graphrag/index/reporting/console_workflow_callbacks.py b/graphrag/callbacks/console_workflow_callbacks.py similarity index 87% rename from graphrag/index/reporting/console_workflow_callbacks.py rename to graphrag/callbacks/console_workflow_callbacks.py index b1ab1278f7..4e70ba7109 100644 --- a/graphrag/index/reporting/console_workflow_callbacks.py +++ b/graphrag/callbacks/console_workflow_callbacks.py @@ -1,13 +1,13 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""Console-based reporter for the workflow engine.""" +"""A logger that emits updates from the indexing engine to the console.""" from datashaper import NoopWorkflowCallbacks class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks): - """A reporter that writes to a console.""" + """A logger that writes to a console.""" def on_error( self, diff --git a/graphrag/index/reporting/load_pipeline_reporter.py b/graphrag/callbacks/factories.py similarity index 95% rename from graphrag/index/reporting/load_pipeline_reporter.py rename to graphrag/callbacks/factories.py index 3f3082dc29..3f3b64788f 100644 --- a/graphrag/index/reporting/load_pipeline_reporter.py +++ b/graphrag/callbacks/factories.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""Load pipeline reporter method.""" +"""Create a pipeline reporter.""" from pathlib import Path from typing import cast @@ -20,7 +20,7 @@ from .file_workflow_callbacks import FileWorkflowCallbacks -def load_pipeline_reporter( +def create_pipeline_reporter( config: PipelineReportingConfig | None, root_dir: str | None ) -> WorkflowCallbacks: """Create a reporter for the given pipeline config.""" diff --git a/graphrag/index/reporting/file_workflow_callbacks.py b/graphrag/callbacks/file_workflow_callbacks.py similarity index 91% rename from graphrag/index/reporting/file_workflow_callbacks.py rename to graphrag/callbacks/file_workflow_callbacks.py index 0115de01ec..95ccfea272 100644 --- a/graphrag/index/reporting/file_workflow_callbacks.py +++ b/graphrag/callbacks/file_workflow_callbacks.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""A reporter that writes to a file.""" +"""A logger that emits updates from the indexing engine to a local file.""" import json import logging @@ -14,12 +14,12 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks): - """A reporter that writes to a file.""" + """A logger that writes to a local file.""" _out_stream: TextIOWrapper def __init__(self, directory: str): - """Create a new file-based workflow reporter.""" + """Create a new file-based workflow logger.""" Path(directory).mkdir(parents=True, exist_ok=True) self._out_stream = open( # noqa: PTH123, SIM115 Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict" diff --git a/graphrag/query/structured_search/global_search/callbacks.py b/graphrag/callbacks/global_search_callbacks.py similarity index 93% rename from graphrag/query/structured_search/global_search/callbacks.py rename to graphrag/callbacks/global_search_callbacks.py index f48bb79b82..32c6fc8668 100644 --- a/graphrag/query/structured_search/global_search/callbacks.py +++ b/graphrag/callbacks/global_search_callbacks.py @@ -3,9 +3,10 @@ """GlobalSearch LLM Callbacks.""" -from graphrag.query.llm.base import BaseLLMCallback from graphrag.query.structured_search.base import SearchResult +from .llm_callbacks import BaseLLMCallback + class GlobalSearchLLMCallback(BaseLLMCallback): """GlobalSearch LLM Callbacks.""" diff --git a/graphrag/callbacks/llm_callbacks.py b/graphrag/callbacks/llm_callbacks.py new file mode 100644 index 0000000000..5438e75a54 --- /dev/null +++ b/graphrag/callbacks/llm_callbacks.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Callbacks.""" + + +class BaseLLMCallback: + """Base class for LLM callbacks.""" + + def __init__(self): + self.response = [] + + def on_llm_new_token(self, token: str): + """Handle when a new token is generated.""" + self.response.append(token) diff --git a/graphrag/index/reporting/progress_workflow_callbacks.py b/graphrag/callbacks/progress_workflow_callbacks.py similarity index 93% rename from graphrag/index/reporting/progress_workflow_callbacks.py rename to graphrag/callbacks/progress_workflow_callbacks.py index 68f10d7530..31c29543a5 100644 --- a/graphrag/index/reporting/progress_workflow_callbacks.py +++ b/graphrag/callbacks/progress_workflow_callbacks.py @@ -1,13 +1,13 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""A workflow callback manager that emits updates to a ProgressReporter.""" +"""A workflow callback manager that emits updates.""" from typing import Any from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer -from graphrag.index.progress import ProgressReporter +from graphrag.logging import ProgressReporter class ProgressWorkflowCallbacks(NoopWorkflowCallbacks): diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py index 203d955872..bdf8a63a06 100644 --- a/graphrag/index/__main__.py +++ b/graphrag/index/__main__.py @@ -5,11 +5,11 @@ import argparse +from graphrag.logging import ReporterType from graphrag.utils.cli import dir_exist, file_exist from .cli import index_cli from .emit.types import TableEmitterType -from .progress.types import ReporterType if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 7dfae5b2cf..72835f6603 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -11,22 +11,21 @@ import warnings from pathlib import Path +import graphrag.api as api from graphrag.config import ( CacheType, enable_logging_with_config, load_config, resolve_paths, ) +from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter -from .api import build_index from .emit.types import TableEmitterType from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT from .init_content import INIT_DOTENV, INIT_YAML -from .progress import ProgressReporter, ReporterType -from .progress.load_progress_reporter import load_progress_reporter from .validate_config import validate_config_names # Ignore warnings from numba @@ -118,7 +117,7 @@ def index_cli( output_dir: str | None, ): """Run the pipeline with the given config.""" - progress_reporter = load_progress_reporter(reporter) + progress_reporter = create_progress_reporter(reporter) info, error, success = _logger(progress_reporter) run_id = resume or update_index_id or time.strftime("%Y%m%d-%H%M%S") @@ -161,7 +160,7 @@ def index_cli( _register_signal_handlers(progress_reporter) outputs = asyncio.run( - build_index( + api.build_index( config=config, run_id=run_id, is_resume_run=bool(resume), diff --git a/graphrag/index/input/csv.py b/graphrag/index/input/csv.py index 2e4864a98c..9c93fca8f4 100644 --- a/graphrag/index/input/csv.py +++ b/graphrag/index/input/csv.py @@ -11,9 +11,9 @@ import pandas as pd from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig -from graphrag.index.progress import ProgressReporter from graphrag.index.storage import PipelineStorage from graphrag.index.utils import gen_md5_hash +from graphrag.logging import ProgressReporter log = logging.getLogger(__name__) diff --git a/graphrag/index/input/load_input.py b/graphrag/index/input/load_input.py index 6d62334210..100caf982a 100644 --- a/graphrag/index/input/load_input.py +++ b/graphrag/index/input/load_input.py @@ -12,11 +12,11 @@ from graphrag.config import InputConfig, InputType from graphrag.index.config import PipelineInputConfig -from graphrag.index.progress import NullProgressReporter, ProgressReporter from graphrag.index.storage import ( BlobPipelineStorage, FilePipelineStorage, ) +from graphrag.logging import NullProgressReporter, ProgressReporter from .csv import input_type as csv from .csv import load as load_csv diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index 2a676c0902..7e76bfe1e7 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -11,9 +11,9 @@ import pandas as pd from graphrag.index.config import PipelineInputConfig -from graphrag.index.progress import ProgressReporter from graphrag.index.storage import PipelineStorage from graphrag.index.utils import gen_md5_hash +from graphrag.logging import ProgressReporter DEFAULT_FILE_PATTERN = re.compile( r".*[\\/](?P[^\\/]+)[\\/](?P\d{4})-(?P\d{2})-(?P\d{2})_(?P[^_]+)_\d+\.txt" diff --git a/graphrag/index/progress/__init__.py b/graphrag/index/progress/__init__.py deleted file mode 100644 index 820440dad7..0000000000 --- a/graphrag/index/progress/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Progress-reporting components.""" - -from .types import ( - NullProgressReporter, - PrintProgressReporter, - ProgressReporter, - ReporterType, -) - -__all__ = [ - "NullProgressReporter", - "PrintProgressReporter", - "ProgressReporter", - "ReporterType", -] diff --git a/graphrag/index/progress/types.py b/graphrag/index/progress/types.py deleted file mode 100644 index 1251e2bd67..0000000000 --- a/graphrag/index/progress/types.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Types for status reporting.""" - -from abc import ABC, abstractmethod -from enum import Enum - -from datashaper import Progress - - -class ReporterType(Enum): - """The type of reporter to use.""" - - RICH = "rich" - PRINT = "print" - NONE = "none" - - def __str__(self): - """Return the string representation of the enum value.""" - return self.value - - -class ProgressReporter(ABC): - """ - Abstract base class for progress reporters. - - This is used to report workflow processing progress via mechanisms like progress-bars. - """ - - @abstractmethod - def __call__(self, update: Progress): - """Update progress.""" - - @abstractmethod - def dispose(self): - """Dispose of the progress reporter.""" - - @abstractmethod - def child(self, prefix: str, transient=True) -> "ProgressReporter": - """Create a child progress bar.""" - - @abstractmethod - def force_refresh(self) -> None: - """Force a refresh.""" - - @abstractmethod - def stop(self) -> None: - """Stop the progress reporter.""" - - @abstractmethod - def error(self, message: str) -> None: - """Report an error.""" - - @abstractmethod - def warning(self, message: str) -> None: - """Report a warning.""" - - @abstractmethod - def info(self, message: str) -> None: - """Report information.""" - - @abstractmethod - def success(self, message: str) -> None: - """Report success.""" - - -class NullProgressReporter(ProgressReporter): - """A progress reporter that does nothing.""" - - def __call__(self, update: Progress) -> None: - """Update progress.""" - - def dispose(self) -> None: - """Dispose of the progress reporter.""" - - def child(self, prefix: str, transient: bool = True) -> ProgressReporter: - """Create a child progress bar.""" - return self - - def force_refresh(self) -> None: - """Force a refresh.""" - - def stop(self) -> None: - """Stop the progress reporter.""" - - def error(self, message: str) -> None: - """Report an error.""" - - def warning(self, message: str) -> None: - """Report a warning.""" - - def info(self, message: str) -> None: - """Report information.""" - - def success(self, message: str) -> None: - """Report success.""" - - -class PrintProgressReporter(ProgressReporter): - """A progress reporter that does nothing.""" - - prefix: str - - def __init__(self, prefix: str): - """Create a new progress reporter.""" - self.prefix = prefix - print(f"\n{self.prefix}", end="") # noqa T201 - - def __call__(self, update: Progress) -> None: - """Update progress.""" - print(".", end="") # noqa T201 - - def dispose(self) -> None: - """Dispose of the progress reporter.""" - - def child(self, prefix: str, transient: bool = True) -> "ProgressReporter": - """Create a child progress bar.""" - return PrintProgressReporter(prefix) - - def stop(self) -> None: - """Stop the progress reporter.""" - - def force_refresh(self) -> None: - """Force a refresh.""" - - def error(self, message: str) -> None: - """Report an error.""" - print(f"\n{self.prefix}ERROR: {message}") # noqa T201 - - def warning(self, message: str) -> None: - """Report a warning.""" - print(f"\n{self.prefix}WARNING: {message}") # noqa T201 - - def info(self, message: str) -> None: - """Report information.""" - print(f"\n{self.prefix}INFO: {message}") # noqa T201 - - def success(self, message: str) -> None: - """Report success.""" - print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201 diff --git a/graphrag/index/reporting/__init__.py b/graphrag/index/reporting/__init__.py deleted file mode 100644 index 697d4fc51f..0000000000 --- a/graphrag/index/reporting/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Reporting utilities and implementations for the indexing engine.""" - -from .blob_workflow_callbacks import BlobWorkflowCallbacks -from .console_workflow_callbacks import ConsoleWorkflowCallbacks -from .file_workflow_callbacks import FileWorkflowCallbacks -from .load_pipeline_reporter import load_pipeline_reporter -from .progress_workflow_callbacks import ProgressWorkflowCallbacks - -__all__ = [ - "BlobWorkflowCallbacks", - "ConsoleWorkflowCallbacks", - "FileWorkflowCallbacks", - "ProgressWorkflowCallbacks", - "load_pipeline_reporter", -] diff --git a/graphrag/index/run/run.py b/graphrag/index/run/run.py index dd50c4a1cd..980e804cb7 100644 --- a/graphrag/index/run/run.py +++ b/graphrag/index/run/run.py @@ -13,6 +13,7 @@ import pandas as pd from datashaper import WorkflowCallbacks +from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks from graphrag.index.cache import PipelineCache from graphrag.index.config import ( PipelineConfig, @@ -21,10 +22,6 @@ ) from graphrag.index.emit import TableEmitterType, create_table_emitters from graphrag.index.load_pipeline_config import load_pipeline_config -from graphrag.index.progress import NullProgressReporter, ProgressReporter -from graphrag.index.reporting import ( - ConsoleWorkflowCallbacks, -) from graphrag.index.run.cache import _create_cache from graphrag.index.run.postprocess import ( _create_postprocess_steps, @@ -52,6 +49,10 @@ WorkflowDefinitions, load_workflows, ) +from graphrag.logging import ( + NullProgressReporter, + ProgressReporter, +) from graphrag.utils.storage import _create_storage log = logging.getLogger(__name__) diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py index 7791913883..3617b92540 100644 --- a/graphrag/index/run/utils.py +++ b/graphrag/index/run/utils.py @@ -12,6 +12,7 @@ WorkflowCallbacks, ) +from graphrag.callbacks.factories import create_pipeline_reporter from graphrag.index.cache.memory_pipeline_cache import InMemoryCache from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.config.cache import ( @@ -31,10 +32,9 @@ ) from graphrag.index.context import PipelineRunContext, PipelineRunStats from graphrag.index.input import load_input -from graphrag.index.progress.types import ProgressReporter -from graphrag.index.reporting import load_pipeline_reporter from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage from graphrag.index.storage.typing import PipelineStorage +from graphrag.logging import ProgressReporter log = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def _create_reporter( config: PipelineReportingConfigTypes | None, root_dir: str ) -> WorkflowCallbacks | None: """Create the reporter for the pipeline.""" - return load_pipeline_reporter(config, root_dir) if config else None + return create_pipeline_reporter(config, root_dir) if config else None async def _create_input( diff --git a/graphrag/index/run/workflow.py b/graphrag/index/run/workflow.py index 43d232a356..3c57d1120e 100644 --- a/graphrag/index/run/workflow.py +++ b/graphrag/index/run/workflow.py @@ -15,15 +15,13 @@ WorkflowCallbacksManager, ) +from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks from graphrag.index.context import PipelineRunContext from graphrag.index.emit.table_emitter import TableEmitter -from graphrag.index.progress.types import ProgressReporter -from graphrag.index.reporting.progress_workflow_callbacks import ( - ProgressWorkflowCallbacks, -) from graphrag.index.run.profiling import _write_workflow_stats from graphrag.index.storage.typing import PipelineStorage from graphrag.index.typing import PipelineRunResult +from graphrag.logging import ProgressReporter from graphrag.utils.storage import _load_table_from_storage log = logging.getLogger(__name__) diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/index/storage/blob_pipeline_storage.py index 7e60df9697..456fe7aa26 100644 --- a/graphrag/index/storage/blob_pipeline_storage.py +++ b/graphrag/index/storage/blob_pipeline_storage.py @@ -13,7 +13,7 @@ from azure.storage.blob import BlobServiceClient from datashaper import Progress -from graphrag.index.progress import ProgressReporter +from graphrag.logging import ProgressReporter from .typing import PipelineStorage diff --git a/graphrag/index/storage/file_pipeline_storage.py b/graphrag/index/storage/file_pipeline_storage.py index ee61bab3dd..8e51feddc4 100644 --- a/graphrag/index/storage/file_pipeline_storage.py +++ b/graphrag/index/storage/file_pipeline_storage.py @@ -16,7 +16,7 @@ from aiofiles.ospath import exists from datashaper import Progress -from graphrag.index.progress import ProgressReporter +from graphrag.logging import ProgressReporter from .typing import PipelineStorage diff --git a/graphrag/index/storage/typing.py b/graphrag/index/storage/typing.py index 595baf4efd..6eb727d7da 100644 --- a/graphrag/index/storage/typing.py +++ b/graphrag/index/storage/typing.py @@ -8,7 +8,7 @@ from collections.abc import Iterator from typing import Any -from graphrag.index.progress import ProgressReporter +from graphrag.logging import ProgressReporter class PipelineStorage(metaclass=ABCMeta): diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index 038a87f03c..bc3b8a0ed6 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -10,9 +10,7 @@ from graphrag.config.models import GraphRagConfig from graphrag.index.llm import load_llm, load_llm_embeddings -from graphrag.index.progress import ( - ProgressReporter, -) +from graphrag.logging import ProgressReporter def validate_config_names( diff --git a/graphrag/logging/__init__.py b/graphrag/logging/__init__.py new file mode 100644 index 0000000000..31afc94387 --- /dev/null +++ b/graphrag/logging/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Logging utilities and implementations.""" + +from .console import ConsoleReporter +from .factories import create_progress_reporter +from .null_progress import NullProgressReporter +from .print_progress import PrintProgressReporter +from .rich_progress import RichProgressReporter +from .types import ( + ProgressReporter, + ReporterType, + StatusLogger, +) + +__all__ = [ + # Progress Reporters + "ConsoleReporter", + "NullProgressReporter", + "PrintProgressReporter", + "ProgressReporter", + "ReporterType", + "RichProgressReporter", + "StatusLogger", + "create_progress_reporter", +] diff --git a/graphrag/query/progress.py b/graphrag/logging/console.py similarity index 52% rename from graphrag/query/progress.py rename to graphrag/logging/console.py index ad5bcee734..b00a7e8d9c 100644 --- a/graphrag/query/progress.py +++ b/graphrag/logging/console.py @@ -1,29 +1,14 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""Status Reporter for orchestration.""" +"""Console Reporter.""" -from abc import ABCMeta, abstractmethod from typing import Any +from .types import StatusLogger -class StatusReporter(metaclass=ABCMeta): - """Provides a way to report status updates from the pipeline.""" - @abstractmethod - def error(self, message: str, details: dict[str, Any] | None = None): - """Report an error.""" - - @abstractmethod - def warning(self, message: str, details: dict[str, Any] | None = None): - """Report a warning.""" - - @abstractmethod - def log(self, message: str, details: dict[str, Any] | None = None): - """Report a log.""" - - -class ConsoleStatusReporter(StatusReporter): +class ConsoleReporter(StatusLogger): """A reporter that writes to a console.""" def error(self, message: str, details: dict[str, Any] | None = None): diff --git a/graphrag/index/progress/load_progress_reporter.py b/graphrag/logging/factories.py similarity index 79% rename from graphrag/index/progress/load_progress_reporter.py rename to graphrag/logging/factories.py index 11bbb073be..efd69b7550 100644 --- a/graphrag/index/progress/load_progress_reporter.py +++ b/graphrag/logging/factories.py @@ -1,18 +1,18 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""Load a progress reporter.""" +"""Factory functions for creating loggers.""" -from .rich import RichProgressReporter +from .null_progress import NullProgressReporter +from .print_progress import PrintProgressReporter +from .rich_progress import RichProgressReporter from .types import ( - NullProgressReporter, - PrintProgressReporter, ProgressReporter, ReporterType, ) -def load_progress_reporter( +def create_progress_reporter( reporter_type: ReporterType = ReporterType.NONE, ) -> ProgressReporter: """Load a progress reporter. diff --git a/graphrag/logging/null_progress.py b/graphrag/logging/null_progress.py new file mode 100644 index 0000000000..0539c5c014 --- /dev/null +++ b/graphrag/logging/null_progress.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Null Progress Reporter.""" + +from .types import Progress, ProgressReporter + + +class NullProgressReporter(ProgressReporter): + """A progress reporter that does nothing.""" + + def __call__(self, update: Progress) -> None: + """Update progress.""" + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + + def child(self, prefix: str, transient: bool = True) -> ProgressReporter: + """Create a child progress bar.""" + return self + + def force_refresh(self) -> None: + """Force a refresh.""" + + def stop(self) -> None: + """Stop the progress reporter.""" + + def error(self, message: str) -> None: + """Report an error.""" + + def warning(self, message: str) -> None: + """Report a warning.""" + + def info(self, message: str) -> None: + """Report information.""" + + def success(self, message: str) -> None: + """Report success.""" diff --git a/graphrag/logging/print_progress.py b/graphrag/logging/print_progress.py new file mode 100644 index 0000000000..d529e0dfd6 --- /dev/null +++ b/graphrag/logging/print_progress.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Print Progress Reporter.""" + +from .types import Progress, ProgressReporter + + +class PrintProgressReporter(ProgressReporter): + """A progress reporter that does nothing.""" + + prefix: str + + def __init__(self, prefix: str): + """Create a new progress reporter.""" + self.prefix = prefix + print(f"\n{self.prefix}", end="") # noqa T201 + + def __call__(self, update: Progress) -> None: + """Update progress.""" + print(".", end="") # noqa T201 + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + + def child(self, prefix: str, transient: bool = True) -> "ProgressReporter": + """Create a child progress bar.""" + return PrintProgressReporter(prefix) + + def stop(self) -> None: + """Stop the progress reporter.""" + + def force_refresh(self) -> None: + """Force a refresh.""" + + def error(self, message: str) -> None: + """Report an error.""" + print(f"\n{self.prefix}ERROR: {message}") # noqa T201 + + def warning(self, message: str) -> None: + """Report a warning.""" + print(f"\n{self.prefix}WARNING: {message}") # noqa T201 + + def info(self, message: str) -> None: + """Report information.""" + print(f"\n{self.prefix}INFO: {message}") # noqa T201 + + def success(self, message: str) -> None: + """Report success.""" + print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201 diff --git a/graphrag/index/progress/rich.py b/graphrag/logging/rich_progress.py similarity index 100% rename from graphrag/index/progress/rich.py rename to graphrag/logging/rich_progress.py diff --git a/graphrag/logging/types.py b/graphrag/logging/types.py new file mode 100644 index 0000000000..5b2ef26d23 --- /dev/null +++ b/graphrag/logging/types.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Types for status reporting.""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any + +from datashaper import Progress + + +class ReporterType(Enum): + """The type of reporter to use.""" + + RICH = "rich" + PRINT = "print" + NONE = "none" + + def __str__(self): + """Return the string representation of the enum value.""" + return self.value + + +class StatusLogger(ABC): + """Provides a way to report status updates from the pipeline.""" + + @abstractmethod + def error(self, message: str, details: dict[str, Any] | None = None): + """Report an error.""" + + @abstractmethod + def warning(self, message: str, details: dict[str, Any] | None = None): + """Report a warning.""" + + @abstractmethod + def log(self, message: str, details: dict[str, Any] | None = None): + """Report a log.""" + + +class ProgressReporter(ABC): + """ + Abstract base class for progress reporters. + + This is used to report workflow processing progress via mechanisms like progress-bars. + """ + + @abstractmethod + def __call__(self, update: Progress): + """Update progress.""" + + @abstractmethod + def dispose(self): + """Dispose of the progress reporter.""" + + @abstractmethod + def child(self, prefix: str, transient=True) -> "ProgressReporter": + """Create a child progress bar.""" + + @abstractmethod + def force_refresh(self) -> None: + """Force a refresh.""" + + @abstractmethod + def stop(self) -> None: + """Stop the progress reporter.""" + + @abstractmethod + def error(self, message: str) -> None: + """Report an error.""" + + @abstractmethod + def warning(self, message: str) -> None: + """Report a warning.""" + + @abstractmethod + def info(self, message: str) -> None: + """Report information.""" + + @abstractmethod + def success(self, message: str) -> None: + """Report success.""" diff --git a/graphrag/prompt_tune/__main__.py b/graphrag/prompt_tune/__main__.py index aac668dba1..b55ccb4468 100644 --- a/graphrag/prompt_tune/__main__.py +++ b/graphrag/prompt_tune/__main__.py @@ -6,9 +6,9 @@ import argparse import asyncio +from graphrag.api import DocSelectionType from graphrag.utils.cli import dir_exist, file_exist -from .api import DocSelectionType from .cli import prompt_tune from .generator import MAX_TOKEN_COUNT from .loader import MIN_CHUNK_SIZE diff --git a/graphrag/prompt_tune/cli.py b/graphrag/prompt_tune/cli.py index dd964260ca..0232b0537c 100644 --- a/graphrag/prompt_tune/cli.py +++ b/graphrag/prompt_tune/cli.py @@ -5,21 +5,20 @@ from pathlib import Path +import graphrag.api as api from graphrag.config import load_config -from graphrag.index.progress import PrintProgressReporter +from graphrag.logging import PrintProgressReporter -from . import api from .generator.community_report_summarization import COMMUNITY_SUMMARIZATION_FILENAME from .generator.entity_extraction_prompt import ENTITY_EXTRACTION_FILENAME from .generator.entity_summarization_prompt import ENTITY_SUMMARIZATION_FILENAME -from .types import DocSelectionType async def prompt_tune( config: str, root: str, domain: str, - selection_method: DocSelectionType, + selection_method: api.DocSelectionType, limit: int, max_tokens: int, chunk_size: int, diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index a61f0a4726..5fd5719666 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -12,8 +12,8 @@ from graphrag.index.input import load_input from graphrag.index.llm import load_llm_embeddings from graphrag.index.operations.chunk_text import chunk_text -from graphrag.index.progress.types import ProgressReporter from graphrag.llm.types.llm_types import EmbeddingLLM +from graphrag.logging import ProgressReporter from graphrag.prompt_tune.types import DocSelectionType MIN_CHUNK_OVERLAP = 0 diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 863f16971c..23312ec4d2 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -9,13 +9,12 @@ import pandas as pd +import graphrag.api as api from graphrag.config import GraphRagConfig, load_config, resolve_paths from graphrag.index.create_pipeline_config import create_pipeline_config -from graphrag.index.progress import PrintProgressReporter +from graphrag.logging import PrintProgressReporter from graphrag.utils.storage import _create_storage, _load_table_from_storage -from . import api - reporter = PrintProgressReporter("") diff --git a/graphrag/query/llm/base.py b/graphrag/query/llm/base.py index 2c18bb29a1..a3a8cceb7f 100644 --- a/graphrag/query/llm/base.py +++ b/graphrag/query/llm/base.py @@ -7,16 +7,7 @@ from collections.abc import AsyncGenerator, Generator from typing import Any - -class BaseLLMCallback: - """Base class for LLM callbacks.""" - - def __init__(self): - self.response = [] - - def on_llm_new_token(self, token: str): - """Handle when a new token is generated.""" - self.response.append(token) +from graphrag.callbacks.llm_callbacks import BaseLLMCallback class BaseLLM(ABC): diff --git a/graphrag/query/llm/oai/base.py b/graphrag/query/llm/oai/base.py index 6181c0b2a5..08a90d98a4 100644 --- a/graphrag/query/llm/oai/base.py +++ b/graphrag/query/llm/oai/base.py @@ -8,9 +8,9 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI +from graphrag.logging import ConsoleReporter, StatusLogger from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType -from graphrag.query.progress import ConsoleStatusReporter, StatusReporter class BaseOpenAILLM(ABC): @@ -87,7 +87,7 @@ def sync_client(self, client: OpenAI | AzureOpenAI): class OpenAILLMImpl(BaseOpenAILLM): """Orchestration OpenAI LLM Implementation.""" - _reporter: StatusReporter = ConsoleStatusReporter() + _reporter: StatusLogger = ConsoleReporter() def __init__( self, @@ -100,7 +100,7 @@ def __init__( organization: str | None = None, max_retries: int = 10, request_timeout: float = 180.0, - reporter: StatusReporter | None = None, + reporter: StatusLogger | None = None, ): self.api_key = api_key self.azure_ad_token_provider = azure_ad_token_provider @@ -111,7 +111,7 @@ def __init__( self.organization = organization self.max_retries = max_retries self.request_timeout = request_timeout - self.reporter = reporter or ConsoleStatusReporter() + self.reporter = reporter or ConsoleReporter() try: # Create OpenAI sync and async clients @@ -181,7 +181,7 @@ def _create_openai_client(self): class OpenAITextEmbeddingImpl(BaseTextEmbedding): """Orchestration OpenAI Text Embedding Implementation.""" - _reporter: StatusReporter | None = None + _reporter: StatusLogger | None = None def _create_openai_client(self, api_type: OpenaiApiType): """Create a new synchronous and asynchronous OpenAI client instance.""" diff --git a/graphrag/query/llm/oai/chat_openai.py b/graphrag/query/llm/oai/chat_openai.py index 7dc3579a19..621ebecebe 100644 --- a/graphrag/query/llm/oai/chat_openai.py +++ b/graphrag/query/llm/oai/chat_openai.py @@ -15,13 +15,13 @@ wait_exponential_jitter, ) +from graphrag.logging import StatusLogger from graphrag.query.llm.base import BaseLLM, BaseLLMCallback from graphrag.query.llm.oai.base import OpenAILLMImpl from graphrag.query.llm.oai.typing import ( OPENAI_RETRY_ERROR_TYPES, OpenaiApiType, ) -from graphrag.query.progress import StatusReporter _MODEL_REQUIRED_MSG = "model is required" @@ -42,7 +42,7 @@ def __init__( max_retries: int = 10, request_timeout: float = 180.0, retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore - reporter: StatusReporter | None = None, + reporter: StatusLogger | None = None, ): OpenAILLMImpl.__init__( self=self, diff --git a/graphrag/query/llm/oai/embedding.py b/graphrag/query/llm/oai/embedding.py index f40372dbce..6b39a0017f 100644 --- a/graphrag/query/llm/oai/embedding.py +++ b/graphrag/query/llm/oai/embedding.py @@ -18,6 +18,7 @@ wait_exponential_jitter, ) +from graphrag.logging import StatusLogger from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.base import OpenAILLMImpl from graphrag.query.llm.oai.typing import ( @@ -25,7 +26,6 @@ OpenaiApiType, ) from graphrag.query.llm.text_utils import chunk_text -from graphrag.query.progress import StatusReporter class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl): @@ -46,7 +46,7 @@ def __init__( max_retries: int = 10, request_timeout: float = 180.0, retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore - reporter: StatusReporter | None = None, + reporter: StatusLogger | None = None, ): OpenAILLMImpl.__init__( self=self, diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index 7076b2106b..5e8a71b937 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -14,6 +14,7 @@ import pandas as pd import tiktoken +from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback from graphrag.llm.openai.utils import try_parse_json_object from graphrag.query.context_builder.builders import GlobalContextBuilder from graphrag.query.context_builder.conversation_history import ( @@ -22,9 +23,6 @@ from graphrag.query.llm.base import BaseLLM from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import BaseSearch, SearchResult -from graphrag.query.structured_search.global_search.callbacks import ( - GlobalSearchLLMCallback, -) from graphrag.query.structured_search.global_search.map_system_prompt import ( MAP_SYSTEM_PROMPT, ) diff --git a/graphrag/vector_stores/__init__.py b/graphrag/vector_stores/__init__.py index d4c11760aa..764d51b10e 100644 --- a/graphrag/vector_stores/__init__.py +++ b/graphrag/vector_stores/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""A package containing vector-storage implementations.""" +"""A module containing vector storage implementations.""" from .azure_ai_search import AzureAISearch from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult diff --git a/pyproject.toml b/pyproject.toml index ec32220d04..5fb3aef22c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,9 +108,9 @@ requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] build-backend = "poetry_dynamic_versioning.backend" [tool.poe.tasks] -_sort_imports = "ruff check --select I --fix . --preview" -_format_code = "ruff format . --preview" -_ruff_check = 'ruff check . --preview' +_sort_imports = "ruff check --select I --fix ." +_format_code = "ruff format ." +_ruff_check = 'ruff check .' _pyright = "pyright" _convert_local_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/notebooks/ --output="{notebook_name}_nb" --template=docsite/nbdocsite_template --to markdown examples_notebooks/local_search.ipynb' _convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/notebooks/ --output="{notebook_name}_nb" --template=docsite/nbdocsite_template --to markdown examples_notebooks/global_search.ipynb' @@ -119,9 +119,9 @@ _semversioner_changelog = "semversioner changelog > CHANGELOG.md" _semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)" semversioner_add = "semversioner add-change" coverage_report = 'coverage report --omit "**/tests/**" --show-missing' -check_format = 'ruff format . --check --preview' -fix = "ruff --preview check --fix ." -fix_unsafe = "ruff check --preview --fix --unsafe-fixes ." +check_format = 'ruff format . --check' +fix = "ruff check --fix ." +fix_unsafe = "ruff check --fix --unsafe-fixes ." _test_all = "coverage run -m pytest ./tests" test_unit = "pytest ./tests/unit" @@ -164,10 +164,12 @@ target-version = "py310" extend-include = ["*.ipynb"] [tool.ruff.format] +preview = true docstring-code-format = true docstring-code-line-length = 20 [tool.ruff.lint] +preview = true select = [ "E4", "E7",