Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- #v1 Make most V1 public concrete classes final.
- Refactor `CheckpointLayout`, splitting `load()` into `load_pytree()` and
  `load_checkpointables()`, each with its own dedicated loading logic.
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
tests
- Refactor `CompositeHandler` logic into the orbax layout objects and handler
resolution utility, deprecating and deleting the `CompositeHandler` class.

## [0.11.32] - 2026-01-20

Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/experimental/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from orbax.checkpoint.experimental.v1._src.context.context import (
Context,
)
from orbax.checkpoint.experimental.v1._src.layout.orbax_layout import (
from orbax.checkpoint.experimental.v1._src.layout.registry import (
is_orbax_checkpoint,
)
from orbax.checkpoint.experimental.v1._src.loading.loading import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ async def _create_orbax_identifier_file(
)


# TODO(b/477603241): Remove this class and delete this file and its tests.
class CompositeHandler:
"""CompositeHandler.

Expand Down Expand Up @@ -291,8 +292,15 @@ def _get_saved_handler_typestrs(
directory,
)

# TODO(b/475265289): Currently, we rely solely on CHECKPOINT_METADATA to
# find available checkpointables, ignoring valid subdirectories. We
# should update the composite handler to validate subdirectories to
# check if any either represents a valid pytree checkpointable or has a
# name that is registered in the handler registry.
saved_handler_typestrs: dict[str, str] = {}
for checkpointable_path in directory.iterdir():
if not checkpointable_path.is_dir():
continue
serialized_metadata = self._metadata_store.read(
checkpoint_metadata.step_metadata_file_path(checkpointable_path)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Logic for resolving handlers for loading."""

import itertools
from typing import Any

from absl import logging
from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
from orbax.checkpoint.experimental.v1._src.path import types as path_types


def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]:
  """Returns the names of at most `limit` immediate subdirectories.

  Used to keep error messages short when a directory contains many entries.
  """
  subdir_names = (child.name for child in directory.iterdir() if child.is_dir())
  return list(itertools.islice(subdir_names, limit))


# Guidance appended to loading errors for checkpoints that may have been
# written with the Orbax V0 API and so cannot be interpreted by the V1 path.
_V0_ERROR_MESSAGE = (
    "If your checkpoint was saved with the Orbax V0 API, please follow the"
    " instructions at"
    " https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/orbax_v0_to_v1_migration.html"
    " to load it with the Orbax V1 API."
)


def _existing_checkpointable_names(directory: path_types.Path) -> set[str]:
  """Returns the names of all immediate subdirectories of `directory`.

  NOTE(review): appears unused within this module — presumably kept for
  callers elsewhere; confirm before removing.
  """
  names: set[str] = set()
  for entry in directory.iterdir():
    if entry.is_dir():
      names.add(entry.name)
  return names


async def get_step_metadata(
    directory: path_types.Path,
) -> step_metadata_serialization.StepMetadata:
  """Returns the step metadata for a given path, normalized for V1."""
  metadata_file = metadata_serialization.checkpoint_metadata_file_path(
      directory
  )
  serialized = await metadata_serialization.read(metadata_file)
  # A missing or empty metadata file deserializes from an empty dict rather
  # than raising, yielding a StepMetadata with default fields.
  return step_metadata_serialization.deserialize(serialized or {})


def get_handlers_for_save(
    handler_registry: registration.CheckpointableHandlerRegistry,
    checkpointables: dict[str, Any],
) -> dict[str, handler_types.CheckpointableHandler]:
  """Returns a mapping from checkpointable name to handler.

  Each checkpointable is resolved against `handler_registry` by name and
  by the concrete object being saved.
  """
  handlers: dict[str, handler_types.CheckpointableHandler] = {}
  for name, checkpointable in checkpointables.items():
    handlers[name] = registration.resolve_handler_for_save(
        handler_registry, checkpointable, name=name
    )
  return handlers


async def get_handlers_for_load(
    directory: path_types.Path,
    handler_registry: registration.CheckpointableHandlerRegistry,
    abstract_checkpointables: dict[str, Any],
) -> dict[str, handler_types.CheckpointableHandler]:
  """Returns a mapping from checkpointable name to handler.

  Args:
    directory: The checkpoint directory being loaded.
    handler_registry: Registry used to resolve a handler for each
      checkpointable.
    abstract_checkpointables: Mapping from checkpointable name to an abstract
      checkpointable (or None). If falsy, every checkpointable recorded in the
      checkpoint metadata is loaded.

  Returns:
    A dict mapping each requested checkpointable name to its resolved handler.

  Raises:
    KeyError: If a requested name is not present in the checkpoint.
  """
  saved_typestrs = await _get_saved_handler_typestrs(directory)
  # No explicit request means "load everything the checkpoint recorded".
  abstract_checkpointables = abstract_checkpointables or {
      name: None for name in saved_typestrs
  }

  resolved_handlers: dict[str, handler_types.CheckpointableHandler] = {}
  for name, abstract_checkpointable in abstract_checkpointables.items():
    if name not in saved_typestrs:
      # Fix: interpolating `dict.keys()` directly rendered as
      # "dict_keys([...])" in the user-facing error; show a sorted list.
      raise KeyError(
          f'Checkpointable "{name}" was not found in the checkpoint.'
          f" Available names: {sorted(saved_typestrs)}"
      )
    resolved_handlers[name] = registration.resolve_handler_for_load(
        handler_registry,
        abstract_checkpointable,
        name=name,
        handler_typestr=saved_typestrs[name],
    )
  return resolved_handlers


async def _get_saved_handler_typestrs(
    directory: path_types.Path,
) -> dict[str, str]:
  """Reads from the checkpoint metadata to get saved handler typestrs.

  First tries the step-level _CHECKPOINT_METADATA file in `directory`; if it
  is absent, falls back to reading per-checkpointable metadata from each
  subdirectory.

  Args:
    directory: The checkpoint (step) directory to inspect.

  Returns:
    A mapping from checkpointable name to its saved handler typestr.

  Raises:
    ValueError: If the metadata's `item_handlers` shape does not match what
      the directory layout implies (see messages below).
  """
  step_metadata_file_path = checkpoint_metadata.step_metadata_file_path(
      directory
  )
  if await async_path.exists(step_metadata_file_path):
    step_metadata = await get_step_metadata(directory)
    # A dict of item_handlers means a composite (per-checkpointable) save;
    # anything else (e.g. a single typestr) is inconsistent with the
    # subdirectory layout and is rejected.
    if isinstance(step_metadata.item_handlers, dict):
      return step_metadata.item_handlers  # found step level metadata.
    raise ValueError(
        f"Path at {directory} contains subdirectories:"
        f" {_subdirs(directory)}, which are expected to"
        " match the keys given by the _CHECKPOINT_METADATA file:"
        f" {step_metadata.item_handlers}. If you intended to load a pytree"
        " checkpoint from the given path, then please consider using"
        " `loading.load_pytree(..., checkpointable_name=None)` instead."
        f" {_V0_ERROR_MESSAGE}"
    )

  logging.warning(
      "Given dir does not contain checkpoint metadata file: %s. Trying to get"
      " saved handlers from checkpoint metadata in each of the checkpointable"
      " subdirectory.",
      directory,
  )

  # TODO(b/475265289): Currently, we rely solely on CHECKPOINT_METADATA to
  # find available checkpointables, ignoring valid subdirectories. We
  # should update this resolution logic to validate subdirectories to
  # check if any either represents a valid pytree checkpointable or has a
  # name that is registered in the handler registry.
  saved_handler_typestrs: dict[str, str] = {}
  for checkpointable_path in await async_path.iterdir(directory):
    # Skip stray files; only subdirectories can be checkpointables.
    if not await async_path.is_dir(checkpointable_path):
      continue
    step_metadata = await get_step_metadata(checkpointable_path)
    # A dict of item_handlers at the subdirectory level implies a nested
    # composite checkpoint, which is not a loadable layout here.
    if isinstance(step_metadata.item_handlers, dict):
      raise ValueError(
          f"Path at {directory} contains subdirectories:"
          f" {_subdirs(directory)}, which are expected to"
          " match the keys given by the _CHECKPOINT_METADATA file:"
          f" {step_metadata.item_handlers}. If you intended to load a pytree"
          " checkpoint from the given path, then please consider using"
          " `loading.load_pytree(..., checkpointable_name=None)` instead."
          f" {_V0_ERROR_MESSAGE}"
      )
    # Here item_handlers is presumably a single handler typestr for this
    # checkpointable (None if the subdirectory has no metadata).
    item_handlers = step_metadata.item_handlers
    if item_handlers is not None:
      checkpointable_name = checkpointable_path.name
      saved_handler_typestrs[checkpointable_name] = item_handlers
  return saved_handler_typestrs
Original file line number Diff line number Diff line change
Expand Up @@ -89,24 +89,21 @@ async def validate_pytree(
"""
...

async def load(
async def load_pytree(
self,
path: Path,
abstract_checkpointables: dict[str, Any] | None = None,
) -> Awaitable[dict[str, Any]]:
"""Loads the checkpoint from the given directory.
checkpointable_name: str | None = None,
abstract_pytree: Any | None = None,
) -> Awaitable[Any]:
"""Loads a PyTree from the checkpoint.

Args:
path: The path to the checkpoint.
abstract_checkpointables: A dictionary of abstract checkpointables.
Dictionary keys represent the names of the checkpointables, while the
values are the abstract checkpointable objects themselves.
checkpointable_name: The name of the checkpointable to load.
abstract_pytree: The abstract PyTree structure.

Returns:
An awaitable dictionary of checkpointables. Dictionary keys represent the
names of
the checkpointables, while the values are the checkpointable objects
themselves.
An awaitable PyTree.
"""
...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _load_leaf(leaf: Any, abstract_leaf: jax.ShapeDtypeStruct):
async def _load_numpy(
path: Path,
abstract_pytree: tree_types.PyTreeOf[jax.ShapeDtypeStruct] | None = None,
) -> dict[str, Any]:
) -> Any:
"""Loads numpy checkpoint as numpy arrays or sharded jax arrays."""
npz_file = await asyncio.to_thread(np.load, path, allow_pickle=True)
try:
Expand All @@ -112,7 +112,7 @@ async def _load_numpy(
finally:
npz_file.close()

return {checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: restored_pytree}
return restored_pytree


class NumpyLayout(CheckpointLayout):
Expand Down Expand Up @@ -193,20 +193,30 @@ def _read_metadata_sync():
commit_timestamp_nsecs=commit_timestamp_nsecs,
)

async def load(
async def load_pytree(
self,
path: Path,
abstract_checkpointables: (
dict[str, tree_types.PyTreeOf[jax.ShapeDtypeStruct]] | None
) = None,
) -> Awaitable[dict[str, tree_types.PyTreeOf[Any]]]:
"""Loads a NumPy checkpoint file."""
abstract_pytree = None
if abstract_checkpointables:
abstract_pytree = abstract_checkpointables.get(
checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
)
return _load_numpy(path, abstract_pytree)
checkpointable_name: str | None = None,
abstract_pytree: Any | None = None,
) -> Awaitable[tree_types.PyTreeOf[Any]]:
"""Loads a NumPy checkpoint file.

If `abstract_pytree` is provided, it attempts to load numpy arrays as
sharded `jax.Arrays` onto devices.

Args:
path: The path to the checkpoint.
checkpointable_name: The name of the pytree checkpointable to load,
unused in this case.
abstract_pytree: An optional PyTree of abstract arrays specifying sharding
information.

Returns:
An awaitable of the loaded PyTree.
"""
del checkpointable_name
load_awaitable = _load_numpy(path, abstract_pytree)
return load_awaitable

async def save(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ async def test_load_numpy_checkpoint(self, dtype: np.dtype):

# Load the checkpoint
layout = NumpyLayout()
restore_fn = await layout.load(test_path)
restored_checkpointables = await restore_fn
pytree = restored_checkpointables['pytree']
restore_fn = await layout.load_pytree(test_path)
pytree = await restore_fn

# Verify restored data
if np.issubdtype(dtype, np.floating):
Expand Down
Loading
Loading