Skip to content

Refactor StorageFactory class to use registration functionality #1944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250521041234833898.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Refactored StorageFactory to use a registration-based approach"
}
83 changes: 64 additions & 19 deletions graphrag/storage/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

from contextlib import suppress
from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import OutputType
Expand All @@ -14,6 +15,8 @@
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage

if TYPE_CHECKING:
from collections.abc import Callable

from graphrag.storage.pipeline_storage import PipelineStorage


Expand All @@ -26,29 +29,71 @@ class StorageFactory:
for individual enforcement of required/optional arguments.
"""

storage_types: ClassVar[dict[str, type]] = {}
_storage_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
storage_types: ClassVar[dict[str, type]] = {} # For backward compatibility

@classmethod
def register(cls, storage_type: str, storage: type):
"""Register a custom storage implementation."""
cls.storage_types[storage_type] = storage
def register(
cls, storage_type: str, creator: Callable[..., PipelineStorage]
) -> None:
"""Register a custom storage implementation.

Args:
storage_type: The type identifier for the storage.
creator: A callable that creates an instance of the storage.
"""
cls._storage_registry[storage_type] = creator

# For backward compatibility with code that may access storage_types directly
if (
callable(creator)
and hasattr(creator, "__annotations__")
and "return" in creator.__annotations__
):
with suppress(TypeError, KeyError):
cls.storage_types[storage_type] = creator.__annotations__["return"]

@classmethod
def create_storage(
cls, storage_type: OutputType | str, kwargs: dict
) -> PipelineStorage:
"""Create or get a storage object from the provided type."""
match storage_type:
case OutputType.blob:
return create_blob_storage(**kwargs)
case OutputType.cosmosdb:
return create_cosmosdb_storage(**kwargs)
case OutputType.file:
return create_file_storage(**kwargs)
case OutputType.memory:
return MemoryPipelineStorage()
case _:
if storage_type in cls.storage_types:
return cls.storage_types[storage_type](**kwargs)
msg = f"Unknown storage type: {storage_type}"
raise ValueError(msg)
"""Create a storage object from the provided type.

Args:
storage_type: The type of storage to create.
kwargs: Additional keyword arguments for the storage constructor.

Returns
-------
A PipelineStorage instance.

Raises
------
ValueError: If the storage type is not registered.
"""
storage_type_str = (
storage_type.value if isinstance(storage_type, OutputType) else storage_type
)

if storage_type_str not in cls._storage_registry:
msg = f"Unknown storage type: {storage_type}"
raise ValueError(msg)

return cls._storage_registry[storage_type_str](**kwargs)

@classmethod
def get_storage_types(cls) -> list[str]:
"""Get the registered storage implementations."""
return list(cls._storage_registry.keys())

@classmethod
def is_supported_storage_type(cls, storage_type: str) -> bool:
"""Check if the given storage type is supported."""
return storage_type in cls._storage_registry


# --- Register default implementations ---
StorageFactory.register(OutputType.blob.value, create_blob_storage)
StorageFactory.register(OutputType.cosmosdb.value, create_cosmosdb_storage)
StorageFactory.register(OutputType.file.value, create_file_storage)
StorageFactory.register(OutputType.memory.value, lambda **_: MemoryPipelineStorage())
45 changes: 39 additions & 6 deletions tests/integration/storage/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from graphrag.storage.factory import StorageFactory
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage

# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="


@pytest.mark.skip(reason="Blob storage emulator is not available in this environment")
def test_create_blob_storage():
kwargs = {
"type": "blob",
Expand Down Expand Up @@ -61,13 +63,44 @@ def test_create_memory_storage():


def test_register_and_create_custom_storage():
class CustomStorage:
def __init__(self, **kwargs):
pass

StorageFactory.register("custom", CustomStorage)
"""Test registering and creating a custom storage type."""
from unittest.mock import MagicMock

# Create a mock that satisfies the PipelineStorage interface
custom_storage_class = MagicMock(spec=PipelineStorage)
# Make the mock return a mock instance when instantiated
instance = MagicMock()
# We can set attributes on the mock instance, even if they don't exist on PipelineStorage
instance.initialized = True
custom_storage_class.return_value = instance

StorageFactory.register("custom", lambda **kwargs: custom_storage_class(**kwargs))
storage = StorageFactory.create_storage("custom", {})
assert isinstance(storage, CustomStorage)

assert custom_storage_class.called
assert storage is instance
# Access the attribute we set on our mock
assert storage.initialized is True # type: ignore # Attribute only exists on our mock

# Check if it's in the list of registered storage types
assert "custom" in StorageFactory.get_storage_types()
assert StorageFactory.is_supported_storage_type("custom")


def test_get_storage_types():
storage_types = StorageFactory.get_storage_types()
# Check that built-in types are registered
assert OutputType.file.value in storage_types
assert OutputType.memory.value in storage_types
assert OutputType.blob.value in storage_types
assert OutputType.cosmosdb.value in storage_types


def test_backward_compatibility():
"""Test that the storage_types attribute is still accessible for backward compatibility."""
assert hasattr(StorageFactory, "storage_types")
# The storage_types attribute should be a dict
assert isinstance(StorageFactory.storage_types, dict)


def test_create_unknown_storage():
Expand Down
Loading