Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Any, Callable, Dict, List, Optional

import docker
import docker.types
from docker.errors import DockerException

import docker
from mindtrace.core import Mindtrace


Expand Down
79 changes: 44 additions & 35 deletions mindtrace/registry/mindtrace/registry/core/_registry_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
BatchResult,
OnConflict,
VerifyLevel,
Version,
)


def _version_sort_key(v: str) -> list[int]:
"""Sort key for semantic version strings."""
def _version_sort_key(v: str, digits: int) -> tuple[int, ...]:
"""Sort key for fixed-width numeric version strings."""
try:
return [int(x) for x in v.split(".")]
return Version(v, digits=digits).parts
except ValueError:
return [0]
return tuple(0 for _ in range(digits))


class _RegistryCore(Mindtrace):
Expand All @@ -51,6 +52,7 @@ def __init__(
backend: str | Path | RegistryBackend | None = None,
version_objects: bool | None = None,
mutable: bool | None = None,
version_digits: int | None = None,
versions_cache_ttl: float = 60.0,
**kwargs,
):
Expand Down Expand Up @@ -80,12 +82,14 @@ def __init__(

self.backend = backend

# Initialize registry metadata (version_objects, mutable) in a single read/write
self.version_objects, self.mutable = self._initialize_registry_metadata(
# Initialize registry metadata (version_objects, mutable, version_digits) in a single read/write
self.version_objects, self.mutable, self.version_digits = self._initialize_registry_metadata(
version_objects=version_objects if version_objects is not None else False,
version_objects_explicit=version_objects is not None,
mutable=mutable if mutable is not None else False,
mutable_explicit=mutable is not None,
version_digits=version_digits if version_digits is not None else 3,
version_digits_explicit=version_digits is not None,
)

self._artifact_store = LocalArtifactStore(
Expand Down Expand Up @@ -141,7 +145,9 @@ def _initialize_registry_metadata(
version_objects_explicit: bool,
mutable: bool,
mutable_explicit: bool,
) -> tuple[bool, bool]:
version_digits: int,
version_digits_explicit: bool,
) -> tuple[bool, bool, int]:
"""Initialize registry metadata in a single read/write cycle.

Reads existing metadata once, validates both version_objects and mutable against
Expand Down Expand Up @@ -183,14 +189,29 @@ def _initialize_registry_metadata(
)
mutable = stored_mut

# Resolve version_digits
if version_digits < 1:
raise ValueError(f"version_digits must be >= 1, got {version_digits}")

stored_digits = existing.get("version_digits")
if stored_digits is not None:
if version_digits_explicit and stored_digits != version_digits:
raise ValueError(
f"Version digits conflict: existing registry has version_digits={stored_digits}, "
f"but new Registry instance was created with version_digits={version_digits}. "
"All Registry instances must use the same version_digits setting."
)
version_digits = stored_digits

# Write back only if any value was not yet persisted
if stored_vo is None or stored_mut is None:
if stored_vo is None or stored_mut is None or stored_digits is None:
existing.setdefault("materializers", {})
existing["version_objects"] = version_objects
existing["mutable"] = mutable
existing["version_digits"] = version_digits
self.backend.save_registry_metadata(existing)

return version_objects, mutable
return version_objects, mutable, version_digits

def _resolve_load_version(self, name: str, version: str | None) -> str:
"""Resolve a version string for loading into a concrete version.
Expand All @@ -209,7 +230,7 @@ def _resolve_load_version(self, name: str, version: str | None) -> str:
ValueError: If explicit version format is invalid.
"""
if not self.version_objects:
return "1"
return str(Version("1", digits=self.version_digits))

if version is None or version == "latest":
resolved = self._latest(name)
Expand Down Expand Up @@ -803,7 +824,7 @@ def _delete_single(
if version is None:
# Delete all versions
if not self.version_objects:
versions_to_delete = ["1"]
versions_to_delete = [str(Version("1", digits=self.version_digits))]
else:
versions_to_delete = self.list_versions(name)
if not versions_to_delete:
Expand Down Expand Up @@ -851,7 +872,7 @@ def _delete_batch(

if v is None:
# Delete all versions
all_versions = ["1"] if not self.version_objects else self.list_versions(n)
all_versions = [str(Version("1", digits=self.version_digits))] if not self.version_objects else self.list_versions(n)
if not all_versions:
result.errors[original_key] = {
"error": "RegistryObjectNotFound",
Expand Down Expand Up @@ -934,7 +955,7 @@ def info(self, name: str | None = None, version: str | None = None) -> Dict[str,
items = [(n, v) for n in self.list_objects() for v in self.list_versions(n)]
elif version is not None:
# Specific version (resolve "latest")
resolved_version = self._latest(name) if version == "latest" else version
resolved_version = self._latest(name) if version == "latest" else self._validate_version(version)
items = [(name, resolved_version)] if resolved_version else []
else:
# All versions for one object
Expand Down Expand Up @@ -1183,20 +1204,10 @@ def _validate_version(self, version: str | None) -> str:
"Resolve to a concrete version before calling."
)

# Remove any 'v' prefix
if version.startswith("v"):
version = version[1:]

# Split into components and validate
try:
components = version.split(".")
# Convert each component to int to validate
[int(c) for c in components]
return version
except ValueError:
raise ValueError(
f"Invalid version string '{version}'. Must be in semantic versioning format (e.g. '1', '1.0', '1.0.0')"
)
return str(Version(version, digits=self.version_digits))
except ValueError as e:
raise ValueError(str(e)) from e

def _format_object_value(self, object_name: str, version: str, class_name: str) -> str:
"""Format object value for display in __str__ method.
Expand Down Expand Up @@ -1257,7 +1268,7 @@ def __str__(self, *, color: bool = True, latest_only: bool = True) -> str:
for object_name, versions in info.items():
version_items = versions.items()
if latest_only and version_items:
version_items = [max(versions.items(), key=lambda kv: _version_sort_key(kv[0]))]
version_items = [max(versions.items(), key=lambda kv: _version_sort_key(kv[0], self.version_digits))]

for version, details in version_items:
meta = details.get("metadata", {})
Expand Down Expand Up @@ -1293,7 +1304,7 @@ def __str__(self, *, color: bool = True, latest_only: bool = True) -> str:
lines.append(f"\n🧠 {object_name}:")
version_items = versions.items()
if latest_only:
version_items = [max(versions.items(), key=lambda kv: _version_sort_key(kv[0]))]
version_items = [max(versions.items(), key=lambda kv: _version_sort_key(kv[0], self.version_digits))]
for version, details in version_items:
cls = details.get("class", "❓ Not registered")
value_str = self._format_object_value(object_name, version, cls)
Expand Down Expand Up @@ -1331,15 +1342,13 @@ def _next_version(self, name: str) -> str:
Next version string
"""
if not self.version_objects:
return "1"
return str(Version("1", digits=self.version_digits))

most_recent = self._latest(name)
if most_recent is None:
return "1"
components = most_recent.split(".")
components[-1] = str(int(components[-1]) + 1)
return str(Version("1", digits=self.version_digits))

return ".".join(components)
return str(Version(most_recent, digits=self.version_digits).bump())

def _latest(self, name: str) -> str:
"""Return the most recent version string for an object.
Expand All @@ -1357,7 +1366,7 @@ def _latest(self, name: str) -> str:
# Filter out temporary versions (those with __temp__ prefix)
versions = [v for v in versions if not v.startswith("__temp__")]

return sorted(versions, key=_version_sort_key)[-1]
return sorted(versions, key=lambda v: _version_sort_key(v, self.version_digits))[-1]

def _register_default_materializers(self, override_preexisting_materializers: bool = False):
"""Register default materializers from the class-level registry.
Expand Down Expand Up @@ -1584,7 +1593,7 @@ def clear(self, clear_registry_metadata: bool = False) -> None:
if clear_registry_metadata:
try:
# Clear registry metadata by creating a new empty metadata file
empty_metadata = {"materializers": {}, "version_objects": False}
empty_metadata = {"materializers": {}, "version_objects": False, "mutable": False, "version_digits": 3}
self.backend.save_registry_metadata(empty_metadata)
except Exception as e:
self.logger.warning(f"Could not clear registry metadata: {e}")
Expand Down
8 changes: 8 additions & 0 deletions mindtrace/registry/mindtrace/registry/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
backend: str | Path | RegistryBackend | None = None,
version_objects: bool | None = None,
mutable: bool | None = None,
version_digits: int | None = None,
versions_cache_ttl: float = 60.0,
use_cache: bool = True,
**kwargs,
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(
backend=backend,
version_objects=version_objects,
mutable=mutable,
version_digits=version_digits,
versions_cache_ttl=versions_cache_ttl,
**kwargs,
)
Expand All @@ -119,6 +121,7 @@ def __init__(
backend=LocalRegistryBackend(uri=cache_dir),
version_objects=self._remote.version_objects,
mutable=True, # cache is always mutable for updates
version_digits=self._remote.version_digits,
versions_cache_ttl=versions_cache_ttl,
**kwargs,
)
Expand All @@ -130,6 +133,7 @@ def __init__(
backend=backend,
version_objects=version_objects,
mutable=mutable,
version_digits=version_digits,
versions_cache_ttl=versions_cache_ttl,
**kwargs,
)
Expand Down Expand Up @@ -159,6 +163,10 @@ def version_objects(self) -> bool:
def mutable(self) -> bool:
return self._core.mutable

@property
def version_digits(self) -> int:
return self._core.version_digits

# ─────────────────────────────────────────────────────────────────────────
# Class-level materializer registry (delegates to _RegistryCore)
# ─────────────────────────────────────────────────────────────────────────
Expand Down
73 changes: 73 additions & 0 deletions mindtrace/registry/mindtrace/registry/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,79 @@
ERROR_UNKNOWN = "UnknownError" # Fallback error type when error info unavailable


@dataclass(frozen=True, order=True)
class Version:
"""Canonical numeric version with fixed-width digits.

A ``Version`` is initialized from a version string (e.g. ``"1"`` or ``"1.2"``)
and a fixed ``digits`` width for its registry context.

Examples:
Version("1", digits=3) == Version("1.0", digits=3) == Version("1.0.0", digits=3)
str(Version("1.0", digits=3)) == "1.0.0"

Notes:
- Optional ``v`` prefix is accepted (e.g. ``"v1.2"``).
- Number of components must be between 1 and ``digits`` (inclusive).
- All components must be non-negative integers.
- Versions with more than ``digits`` components are rejected.
"""

parts: Tuple[int, ...]
digits: int = field(default=3, compare=False)

def __init__(self, value: str, digits: int = 3):
if digits < 1:
raise ValueError(f"digits must be >= 1, got {digits}")
object.__setattr__(self, "digits", digits)
object.__setattr__(self, "parts", self._parse(value, digits))

@classmethod
def _parse(cls, value: str, digits: int) -> Tuple[int, ...]:
if value is None:
raise ValueError("Version cannot be None")

raw = value.strip()
if not raw:
raise ValueError("Version cannot be empty")

if raw.startswith("v"):
raw = raw[1:]

components = raw.split(".")
if len(components) < 1 or len(components) > digits:
raise ValueError(
f"Invalid version string '{value}'. Expected between 1 and {digits} numeric components."
)

try:
parsed = tuple(int(component) for component in components)
except ValueError as exc:
raise ValueError(
f"Invalid version string '{value}'. Expected numeric components like '1.0.0'."
) from exc

if any(component < 0 for component in parsed):
raise ValueError(f"Invalid version string '{value}'. Components must be non-negative integers.")

if len(parsed) < digits:
parsed = parsed + (0,) * (digits - len(parsed))

return parsed

@property
def normalized(self) -> str:
return ".".join(str(component) for component in self.parts)

def bump(self) -> "Version":
components = list(self.parts)
components[-1] += 1
return Version(".".join(str(component) for component in components), digits=self.digits)

def __str__(self) -> str:
return self.normalized


# ─────────────────────────────────────────────────────────────────────────────
# Enums
# ─────────────────────────────────────────────────────────────────────────────
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def test_object_discovery(registry):
registry.save(f"{test_prefix}object:1", "data1_v2", version="2.0.0")
versions = registry.list_versions(f"{test_prefix}object:1")
assert len(versions) == 2
assert "1" in versions # Auto-generated version
assert "1.0.0" in versions # Auto-generated canonical version
assert "2.0.0" in versions


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/mindtrace/cluster/test_docker_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from unittest.mock import Mock, patch

import pytest

from docker.errors import DockerException

from mindtrace.cluster.workers.environments.docker_env import DockerEnvironment


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/mindtrace/jobs/local/test_local_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_priority_queue_with_priority(self, temp_local_client):
def test_receive_from_nonexistent_queue(self, temp_local_client):
"""Test receiving from a queue that doesn't exist."""
client = temp_local_client
with pytest.raises(RegistryObjectNotFound, match="Object nonexistent-queue@1 not found"):
with pytest.raises(RegistryObjectNotFound, match=r"Object nonexistent-queue@1\.0\.0 not found"):
client.receive_message("nonexistent-queue")

def test_receive_message_json_decode_error(self, temp_local_client):
Expand All @@ -246,7 +246,7 @@ def test_receive_message_json_decode_error(self, temp_local_client):
def test_clean_nonexistent_queue(self, temp_local_client):
"""Test cleaning a queue that doesn't exist."""
client = temp_local_client
with pytest.raises(RegistryObjectNotFound, match="Object nonexistent-queue@1 not found"):
with pytest.raises(RegistryObjectNotFound, match=r"Object nonexistent-queue@1\.0\.0 not found"):
client.clean_queue("nonexistent-queue")

def test_count_nonexistent_queue(self, temp_local_client):
Expand Down
Loading
Loading