From cbb8f8788e060f53a3dbe9580eb620c648b3343d Mon Sep 17 00:00:00 2001 From: Josh Bradley Date: Fri, 3 Jan 2025 17:39:44 -0500 Subject: [PATCH] Fix storage class instantiation (#1582) --- .../patch-20250103210427219013.json | 4 + graphrag/cache/factory.py | 7 +- graphrag/storage/blob_pipeline_storage.py | 11 ++- graphrag/storage/factory.py | 3 + tests/integration/storage/test_factory.py | 75 +++++++++++++++++++ 5 files changed, 92 insertions(+), 8 deletions(-) create mode 100644 .semversioner/next-release/patch-20250103210427219013.json create mode 100644 tests/integration/storage/test_factory.py diff --git a/.semversioner/next-release/patch-20250103210427219013.json b/.semversioner/next-release/patch-20250103210427219013.json new file mode 100644 index 0000000000..9a89ae1b5e --- /dev/null +++ b/.semversioner/next-release/patch-20250103210427219013.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "fix instantiation of storage classes." +} diff --git a/graphrag/cache/factory.py b/graphrag/cache/factory.py index 7490233cb7..f44c68953b 100644 --- a/graphrag/cache/factory.py +++ b/graphrag/cache/factory.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, ClassVar from graphrag.config.enums import CacheType -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage from graphrag.storage.file_pipeline_storage import FilePipelineStorage @@ -24,6 +24,9 @@ class CacheFactory: """A factory class for cache implementations. Includes a method for users to register a custom cache implementation. + + Configuration arguments are passed to each cache implementation as kwargs (where possible) + for individual enforcement of required/optional arguments. """ cache_types: ClassVar[dict[str, type]] = {} @@ -50,7 +53,7 @@ def create_cache( FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"]) ) case CacheType.blob: - return JsonPipelineCache(BlobPipelineStorage(**kwargs)) + return JsonPipelineCache(create_blob_storage(**kwargs)) case CacheType.cosmosdb: return JsonPipelineCache(create_cosmosdb_storage(**kwargs)) case _: diff --git a/graphrag/storage/blob_pipeline_storage.py b/graphrag/storage/blob_pipeline_storage.py index f72663052c..5da7598f3e 100644 --- a/graphrag/storage/blob_pipeline_storage.py +++ b/graphrag/storage/blob_pipeline_storage.py @@ -290,13 +290,12 @@ def _abfs_url(self, key: str) -> str: return f"abfs://{path}" -def create_blob_storage( - connection_string: str | None, - storage_account_blob_url: str | None, - container_name: str, - base_dir: str | None, -) -> PipelineStorage: +def create_blob_storage(**kwargs: Any) -> PipelineStorage: """Create a blob based storage.""" + connection_string = kwargs.get("connection_string") + storage_account_blob_url = kwargs.get("storage_account_blob_url") + base_dir = kwargs.get("base_dir") + container_name = kwargs["container_name"] log.info("Creating blob storage at %s", container_name) if container_name is None: msg = "No container name provided for blob storage." diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py index e3346af401..d9243fb7d2 100644 --- a/graphrag/storage/factory.py +++ b/graphrag/storage/factory.py @@ -21,6 +21,9 @@ class StorageFactory: """A factory class for storage implementations. Includes a method for users to register a custom storage implementation. + + Configuration arguments are passed to each storage implementation as kwargs + for individual enforcement of required/optional arguments. """ storage_types: ClassVar[dict[str, type]] = {} diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py new file mode 100644 index 0000000000..81e1781dba --- /dev/null +++ b/tests/integration/storage/test_factory.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""StorageFactory Tests. + +These tests will test the StorageFactory class and the creation of each storage type that is natively supported. +""" + +import sys + +import pytest + +from graphrag.config.enums import StorageType +from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage +from graphrag.storage.factory import StorageFactory +from graphrag.storage.file_pipeline_storage import FilePipelineStorage +from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage + +# cspell:disable-next-line well-known-key +WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" +# cspell:disable-next-line well-known-key +WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + + +def test_create_blob_storage(): + kwargs = { + "type": "blob", + "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, + "base_dir": "testbasedir", + "container_name": "testcontainer", + } + storage = StorageFactory.create_storage(StorageType.blob, kwargs) + assert isinstance(storage, BlobPipelineStorage) + + +@pytest.mark.skipif( + not sys.platform.startswith("win"), + reason="cosmosdb emulator is only available on windows runners at this time", +) +def test_create_cosmosdb_storage(): + kwargs = { + "type": "cosmosdb", + "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, + "base_dir": "testdatabase", + "container_name": "testcontainer", + } + storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs) + assert isinstance(storage, CosmosDBPipelineStorage) + + +def test_create_file_storage(): + kwargs = {"type": "file", "base_dir": "/tmp/teststorage"} + storage = StorageFactory.create_storage(StorageType.file, kwargs) + assert isinstance(storage, FilePipelineStorage) + + +def test_create_memory_storage(): + kwargs = {"type": "memory"} + storage = StorageFactory.create_storage(StorageType.memory, kwargs) + assert isinstance(storage, MemoryPipelineStorage) + + +def test_register_and_create_custom_storage(): + class CustomStorage: + def __init__(self, **kwargs): + pass + + StorageFactory.register("custom", CustomStorage) + storage = StorageFactory.create_storage("custom", {}) + assert isinstance(storage, CustomStorage) + + +def test_create_unknown_storage(): + with pytest.raises(ValueError, match="Unknown storage type: unknown"): + StorageFactory.create_storage("unknown", {})