Skip to content

Commit

Permalink
Fix storage class instantiation (#1582)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgbradley1 authored Jan 3, 2025
1 parent a35cb12 commit cbb8f87
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 8 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250103210427219013.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "fix instantiation of storage classes."
}
7 changes: 5 additions & 2 deletions graphrag/cache/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage

Expand All @@ -24,6 +24,9 @@ class CacheFactory:
"""A factory class for cache implementations.
Includes a method for users to register a custom cache implementation.
Configuration arguments are passed to each cache implementation as kwargs (where possible)
for individual enforcement of required/optional arguments.
"""

cache_types: ClassVar[dict[str, type]] = {}
Expand All @@ -50,7 +53,7 @@ def create_cache(
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
)
case CacheType.blob:
return JsonPipelineCache(BlobPipelineStorage(**kwargs))
return JsonPipelineCache(create_blob_storage(**kwargs))
case CacheType.cosmosdb:
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
case _:
Expand Down
11 changes: 5 additions & 6 deletions graphrag/storage/blob_pipeline_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,12 @@ def _abfs_url(self, key: str) -> str:
return f"abfs://{path}"


def create_blob_storage(
connection_string: str | None,
storage_account_blob_url: str | None,
container_name: str,
base_dir: str | None,
) -> PipelineStorage:
def create_blob_storage(**kwargs: Any) -> PipelineStorage:
"""Create a blob based storage."""
connection_string = kwargs.get("connection_string")
storage_account_blob_url = kwargs.get("storage_account_blob_url")
base_dir = kwargs.get("base_dir")
container_name = kwargs["container_name"]
log.info("Creating blob storage at %s", container_name)
if container_name is None:
msg = "No container name provided for blob storage."
Expand Down
3 changes: 3 additions & 0 deletions graphrag/storage/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class StorageFactory:
"""A factory class for storage implementations.
Includes a method for users to register a custom storage implementation.
Configuration arguments are passed to each storage implementation as kwargs
for individual enforcement of required/optional arguments.
"""

storage_types: ClassVar[dict[str, type]] = {}
Expand Down
75 changes: 75 additions & 0 deletions tests/integration/storage/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""StorageFactory Tests.
These tests exercise the StorageFactory class and verify the creation of each natively supported storage type.
"""

import sys

import pytest

from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.factory import StorageFactory
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage

# Connection strings used by the storage tests below. Both point at loopback
# endpoints (127.0.0.1), so they are test fixtures rather than secrets.
# NOTE(review): these look like the publicly documented default credentials of
# the local Azure Storage (Azurite) and Cosmos DB emulators — confirm.
# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="


def test_create_blob_storage():
    """The factory returns a BlobPipelineStorage for ``StorageType.blob``.

    The configuration uses the well-known connection string, which targets a
    blob endpoint on 127.0.0.1:10000.
    """
    # NOTE(review): presumably requires a local blob emulator to be running — confirm.
    created = StorageFactory.create_storage(
        StorageType.blob,
        {
            "type": "blob",
            "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
            "base_dir": "testbasedir",
            "container_name": "testcontainer",
        },
    )
    assert isinstance(created, BlobPipelineStorage)


@pytest.mark.skipif(
    not sys.platform.startswith("win"),
    reason="cosmosdb emulator is only available on windows runners at this time",
)
def test_create_cosmosdb_storage():
    """The factory returns a CosmosDBPipelineStorage for ``StorageType.cosmosdb``.

    Skipped on non-Windows platforms (see the ``skipif`` reason above).
    """
    cosmos_config = {
        "type": "cosmosdb",
        "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
        "base_dir": "testdatabase",
        "container_name": "testcontainer",
    }
    created = StorageFactory.create_storage(StorageType.cosmosdb, cosmos_config)
    assert isinstance(created, CosmosDBPipelineStorage)


def test_create_file_storage():
    """The factory returns a FilePipelineStorage for ``StorageType.file``."""
    created = StorageFactory.create_storage(
        StorageType.file,
        {"type": "file", "base_dir": "/tmp/teststorage"},
    )
    assert isinstance(created, FilePipelineStorage)


def test_create_memory_storage():
    """The factory returns a MemoryPipelineStorage for ``StorageType.memory``."""
    created = StorageFactory.create_storage(StorageType.memory, {"type": "memory"})
    assert isinstance(created, MemoryPipelineStorage)


def test_register_and_create_custom_storage():
    """Register a custom storage class and create it through the factory.

    The registration is removed again afterwards so that the "custom" entry
    does not leak into other tests through the class-level
    ``StorageFactory.storage_types`` registry.
    """

    class CustomStorage:
        def __init__(self, **kwargs):
            pass

    StorageFactory.register("custom", CustomStorage)
    try:
        storage = StorageFactory.create_storage("custom", {})
        assert isinstance(storage, CustomStorage)
    finally:
        # NOTE(review): assumes register() stores into storage_types — confirm.
        # pop() with a default is a harmless no-op if it does not.
        StorageFactory.storage_types.pop("custom", None)


def test_create_unknown_storage():
    """Requesting an unregistered storage type raises ValueError."""
    expected_message = "Unknown storage type: unknown"
    with pytest.raises(ValueError, match=expected_message):
        StorageFactory.create_storage("unknown", {})

0 comments on commit cbb8f87

Please sign in to comment.