Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions mindtrace/registry/mindtrace/registry/core/_registry_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
version_objects: bool | None = None,
mutable: bool | None = None,
versions_cache_ttl: float = 60.0,
default_materializer: Type[BaseMaterializer] | str | None = None,
**kwargs,
):
"""Initialize the registry core.
Expand All @@ -66,6 +67,8 @@ def __init__(
If explicitly set, must match the stored setting (if any) or a ValueError is raised.
Object level concurrency is handled via lock-free MVCC for both mutable and immutable registries.
versions_cache_ttl: Time-to-live in seconds for the versions cache. Default is 60.0 seconds.
default_materializer: Optional fallback materializer used when no
type-specific materializer is registered.
**kwargs: Additional arguments to pass to the backend.
"""
super().__init__(**kwargs)
Expand All @@ -80,6 +83,11 @@ def __init__(

self.backend = backend

if isinstance(default_materializer, type):
self._default_materializer = f"{default_materializer.__module__}.{default_materializer.__name__}"
else:
self._default_materializer = default_materializer

# Initialize registry metadata (version_objects, mutable) in a single read/write
self.version_objects, self.mutable = self._initialize_registry_metadata(
version_objects=version_objects if version_objects is not None else False,
Expand Down Expand Up @@ -226,7 +234,8 @@ def _find_materializer(self, obj: Any, provided_materializer: Type[BaseMateriali
1. Materializer provided as an argument.
2. Materializer previously registered for the object type.
3. Materializer for any of the object's base classes (checked recursively).
4. The object itself, if it's its own materializer.
4. Registry-level default materializer (if configured).
5. The object itself, if it's its own materializer.

Args:
obj: Object to find materializer for.
Expand Down Expand Up @@ -257,6 +266,7 @@ def get_all_base_classes(cls):
self.registered_materializer(f"{base.__module__}.{base.__name__}")
for base in get_all_base_classes(type(obj))
],
self._default_materializer,
object_class if isinstance(obj, BaseMaterializer) else None,
)
)
Expand Down Expand Up @@ -540,9 +550,14 @@ def _materialize(self, temp_dir: Path, metadata: dict, **kwargs) -> Any:
materializer = instantiate_target(materializer_class, uri=str(temp_dir), artifact_store=self._artifact_store)

if isinstance(object_class, str):
module_name, class_name = object_class.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
object_class = getattr(module, class_name)
try:
module_name, class_name = object_class.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
object_class = getattr(module, class_name)
except Exception:
# Some serializers (e.g., cloudpickle materializers) can deserialize
# objects that are not importable by dotted path (lambdas/local classes).
object_class = Any

return materializer.load(data_type=object_class, **init_params)

Expand Down
6 changes: 6 additions & 0 deletions mindtrace/registry/mindtrace/registry/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
version_objects: bool | None = None,
mutable: bool | None = None,
versions_cache_ttl: float = 60.0,
default_materializer: Type[BaseMaterializer] | None = None,
use_cache: bool = True,
**kwargs,
):
Expand All @@ -97,6 +98,8 @@ def __init__(
mutable: Whether to allow overwriting existing versions. If ``None``
(default), uses the stored setting, or ``False`` for a new registry.
versions_cache_ttl: TTL in seconds for the in-memory versions cache.
default_materializer: Optional fallback materializer used when no
type-specific materializer is registered.
use_cache: Whether to maintain a local cache for remote backends.
Default ``True``.
**kwargs: Additional arguments forwarded to the backend.
Expand All @@ -112,6 +115,7 @@ def __init__(
version_objects=version_objects,
mutable=mutable,
versions_cache_ttl=versions_cache_ttl,
default_materializer=default_materializer,
**kwargs,
)
cache_dir = self._get_cache_dir(self._remote.backend.uri)
Expand All @@ -120,6 +124,7 @@ def __init__(
version_objects=self._remote.version_objects,
mutable=True, # cache is always mutable for updates
versions_cache_ttl=versions_cache_ttl,
default_materializer=default_materializer,
**kwargs,
)
self._core = self._remote
Expand All @@ -131,6 +136,7 @@ def __init__(
version_objects=version_objects,
mutable=mutable,
versions_cache_ttl=versions_cache_ttl,
default_materializer=default_materializer,
**kwargs,
)
self._remote = None # type: ignore
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/mindtrace/registry/core/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,62 @@ def __init__(self):
registry.save("test:custom", custom_obj)


def test_registry_default_materializer_supports_lambda(temp_registry_dir):
    """Round-trip a lambda through a registry configured with a cloudpickle fallback.

    The registry is built with ``default_materializer=CloudpickleMaterializer``;
    the test saves a lambda under a fixed version, loads it back, and checks
    that the loaded object is callable and computes the same result.

    Args:
        temp_registry_dir: Fixture value passed to ``Registry`` as ``backend``.
    """
    from zenml.materializers.cloudpickle_materializer import CloudpickleMaterializer

    # A genuine lambda is the subject of this test, so it is deliberately
    # bound to a name instead of being rewritten as a ``def``.
    increment = lambda x: x + 1  # noqa: E731

    lambda_registry = Registry(
        default_materializer=CloudpickleMaterializer,
        version_objects=True,
        backend=temp_registry_dir,
    )
    lambda_registry.save("test:lambda", increment, version="1.0.0")
    restored = lambda_registry.load("test:lambda", version="1.0.0")

    assert callable(restored)
    assert restored(1) == 2


def test_registry_default_materializer_supports_reduce_protocol(temp_registry_dir):
    """Round-trip an object that defines ``__reduce__`` via the default materializer.

    A function-local ``Point`` class is saved and loaded through a registry
    configured with ``default_materializer=CloudpickleMaterializer``; the
    loaded value must be a ``Point`` carrying the original ``x``.

    Args:
        temp_registry_dir: Fixture value passed to ``Registry`` as ``backend``.
    """
    from zenml.materializers.cloudpickle_materializer import CloudpickleMaterializer

    class Point:
        # Local class whose pickling is customised through ``__reduce__``.
        def __init__(self, x):
            self.x = x

        def __reduce__(self):
            # Reconstruct by calling ``Point(x)``.
            return Point, (self.x,)

    reduce_registry = Registry(
        default_materializer=CloudpickleMaterializer,
        version_objects=True,
        backend=temp_registry_dir,
    )
    original = Point(5)
    reduce_registry.save("test:point", original, version="1.0.0")

    restored = reduce_registry.load("test:point", version="1.0.0")
    assert isinstance(restored, Point)
    assert restored.x == 5


def test_registry_core_default_materializer_string_is_used(temp_registry_dir):
    """Check that a dotted-path string ``default_materializer`` is returned by lookup.

    The registry is constructed with the materializer given as a string, and
    ``_find_materializer`` is then called on an instance of a fresh local
    class; the test expects the configured string to come back unchanged.

    Args:
        temp_registry_dir: Fixture value passed to ``Registry`` as ``backend``.
    """
    # NOTE(review): Registry.__init__ annotates ``default_materializer`` as
    # ``Type[BaseMaterializer] | None`` while _RegistryCore also accepts ``str``;
    # this test passes a string through Registry - confirm the annotation.
    cloudpickle_path = "zenml.materializers.cloudpickle_materializer.CloudpickleMaterializer"

    class CustomObject:
        pass

    string_registry = Registry(
        default_materializer=cloudpickle_path,
        version_objects=True,
        backend=temp_registry_dir,
    )

    resolved = string_registry._find_materializer(CustomObject())
    assert resolved == cloudpickle_path


def test_find_materializer_with_class_object(registry, test_config):
"""Test that _find_materializer() converts a materializer class object to a string.

Expand Down
Loading