Skip to content

#v1 Add compatibility tests for save-by-v0-load-by-v1 and also fix code. #1849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ would want them to be preferred.
- #v1 Add `JsonHandler`.
- #v1 Add `training.Checkpointer`.
- #v1 Add checkpointables support for `training.Checkpointer`.

### Added

- `PartsOf` structure which holds a PyTree whose leaf nodes may be missing.
- #v1 Add compatibility tests for save-by-v0-load-by-v1 and also fix code.

## [0.11.12] - 2025-04-09

Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _sanitize_metadata_path(path: epath.PathLike) -> epath.Path:


def step_metadata_file_path(path: epath.PathLike) -> epath.Path:
"""The path to step metadata file for a given checkpoint directory."""
"""The path to step metadata file, `_CHECKPOINT_METADATA`."""
return _sanitize_metadata_path(path) / _STEP_METADATA_FILENAME


Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/_src/path/format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _has_msgpack_metadata_file(path: epath.Path) -> bool:


def _has_pytree_metadata_file(path: epath.Path) -> bool:
"""Returns True if path contains `_METADATA` or `checkpoint` file."""
return (path / PYTREE_METADATA_FILE).exists() or _has_msgpack_metadata_file(
path
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,47 @@ def _get_loadable_handlers(
handler_typestr=handler_typestr,
)
loadable_checkpointable_names_to_handlers[name] = handler

return loadable_checkpointable_names_to_handlers

def _get_saved_handler_typestrs(
self,
directory: path_types.Path,
) -> dict[str, str]:
"""Reads from the checkpoint metadata to get saved handler typestrs."""
serialized_metadata = self._metadata_store.read(
checkpoint_metadata.step_metadata_file_path(directory)
step_metadata_file_path = checkpoint_metadata.step_metadata_file_path(
directory
)
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
if step_metadata_file_path.exists():
serialized_metadata = self._metadata_store.read(step_metadata_file_path)
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
)
assert isinstance(saved_metadata.item_handlers, dict)
return saved_metadata.item_handlers

logging.warning(
'Given dir contains checkpointables subdirs but no step metadata'
' file=%s. Such dirs can exist if the checkpoints are saved directly'
' using V0 Checkpointer instead of using CheckpointManager or'
' CompositeCheckpointHandler. Will fetch saved handlers from each of'
' the checkpointable subdirectories.',
directory,
)
assert isinstance(saved_metadata.item_handlers, dict)
return saved_metadata.item_handlers
saved_handler_typestrs = {}
for _, checkpointable_name in self._handler_registry.get_all_entries():
if not checkpointable_name:
continue
checkpointable_path = directory / checkpointable_name
if not checkpointable_path.exists() or not checkpointable_path.is_dir():
continue
serialized_metadata = self._metadata_store.read(
checkpoint_metadata.step_metadata_file_path(checkpointable_path)
)
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
)
assert not isinstance(saved_metadata.item_handlers, dict)
item_handlers = saved_metadata.item_handlers
if item_handlers is not None:
saved_handler_typestrs[checkpointable_name] = item_handlers
return saved_handler_typestrs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def load_pytree(
that is restored.

Args:
directory: The directory to load the checkpoint from.
directory: The directory to load the checkpoint from. This directory must
contain a subdirectory named `pytree`.
abstract_pytree: Provides a tree structure for the checkpoint to be restored
into. May be omitted to load exactly as saved, but this is much more
brittle than providing the tree.
Expand Down Expand Up @@ -203,9 +204,7 @@ def get_v0_checkpointer_and_args(
# pylint: disable=protected-access
handlers = composite_handler.CompositeHandler(
context.checkpointables_options.registry
)._get_loadable_handlers(
directory, abstract_checkpointables
)
)._get_loadable_handlers(directory, abstract_checkpointables)
# pylint: enable=protected-access
if not abstract_checkpointables:
abstract_checkpointables = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,12 @@ def validate_pytree_checkpoint(path: path_types.PathLike):
raise FileNotFoundError(f'Checkpoint path {path} does not exist.')
if not path.is_dir():
raise NotADirectoryError(f'Checkpoint path {path} is not a directory.')
metadata_store = checkpoint_metadata.metadata_store(enable_write=False)
# Path points to a single step checkpoint with valid metadata.
checkpoint_metadata_path = checkpoint_metadata.step_metadata_file_path(path)
if not checkpoint_metadata_path.exists():
raise FileNotFoundError(
f'Checkpoint path {path} does not contain a valid metadata file.'
)
if metadata_store.read(checkpoint_metadata_path) is None:
raise ValueError(
f'Failed to read valid metadata for checkpoint path {path}.'
)
if not (path / PYTREE_CHECKPOINTABLE_KEY).exists():
raise FileNotFoundError(
f'Checkpoint path {path} does not contain a PyTree checkpointable'
f' (called "{PYTREE_CHECKPOINTABLE_KEY}").'
f'Checkpoint path {path} must contain a subdirectory named'
f' "{PYTREE_CHECKPOINTABLE_KEY}". Please try inspecting the'
' checkpointable metadata using `ocp.checkpointables_metadata()` or'
' try loading the checkpoint using `ocp.load_checkpointables()`.'
)
if not format_utils._has_pytree_metadata_file( # pylint: disable=protected-access
path / PYTREE_CHECKPOINTABLE_KEY
Expand All @@ -84,3 +75,15 @@ def validate_pytree_checkpoint(path: path_types.PathLike):
' entirely of strings or other non-standard PyTree leaves.',
path,
)
metadata_store = checkpoint_metadata.metadata_store(enable_write=False)
# Path points to a single step checkpoint with valid metadata.
checkpoint_metadata_path = checkpoint_metadata.step_metadata_file_path(path)
if not checkpoint_metadata_path.exists():
raise FileNotFoundError(
f'Checkpoint path {path} does not contain a valid metadata file:'
f' {checkpoint_metadata_path.name}'
)
if metadata_store.read(checkpoint_metadata_path) is None:
raise ValueError(
f'Failed to read valid metadata for checkpoint path {path}.'
)
17 changes: 17 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/testing/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,20 @@ py_library(
"//orbax/checkpoint/experimental/v1/_src/synchronization:multihost",
],
)

py_library(
name = "v0v1_compatibility_save_load_test_base",
srcs = ["v0v1_compatibility_save_load_test_base.py"],
deps = [
":array_utils",
"//checkpoint/orbax/checkpoint:args",
"//checkpoint/orbax/checkpoint:test_utils",
"//checkpoint/orbax/checkpoint/_src/checkpointers:checkpointer",
"//checkpoint/orbax/checkpoint/_src/checkpointers:standard_checkpointer",
"//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler",
"//orbax/checkpoint/experimental/v1",
"//orbax/checkpoint/experimental/v1/_src/path:types",
"//orbax/checkpoint/experimental/v1/_src/synchronization:multihost",
"//orbax/checkpoint/experimental/v1/_src/tree:types",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,40 @@ py_test(
"//orbax/checkpoint/experimental/v1/_src/handlers:pytree_handler_test_base",
],
)

py_test(
name = "v0v1_compatibility_save_load_test_single_worker",
srcs = ["v0v1_compatibility_save_load_test.py"],
args = [
"--jax_platforms=pathways",
"--jax_backend_target=subprocess",
"--pathways_ifrt=true",
"--jax_allow_unused_tpus=true",
],
main = "v0v1_compatibility_save_load_test.py",
deps = [
"//pyglib/contrib/g3_multiprocessing",
"//testing/pybase:parameterized",
"//orbax/checkpoint/experimental/v1:pathways_support",
"//orbax/checkpoint/experimental/v1/_src/testing:v0v1_compatibility_save_load_test_base",
],
)

py_test(
name = "v0v1_compatibility_save_load_test_multi_worker",
srcs = ["v0v1_compatibility_save_load_test.py"],
args = [
"--jax_platforms=pathways",
"--jax_backend_target=subslice",
"--pathways_ifrt=true",
"--jax_allow_unused_tpus=true",
"--pathways_expected_instances=df=1x1,df=1x1,df=1x1,df=1x1",
],
main = "v0v1_compatibility_save_load_test.py",
deps = [
"//pyglib/contrib/g3_multiprocessing",
"//testing/pybase:parameterized",
"//orbax/checkpoint/experimental/v1:pathways_support",
"//orbax/checkpoint/experimental/v1/_src/testing:v0v1_compatibility_save_load_test_base",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def test_missing_keys(self):

with self.subTest('load_pytree'):
with self.assertRaisesRegex(
FileNotFoundError, 'does not contain a PyTree checkpointable'
FileNotFoundError, 'must contain a subdirectory named "pytree"'
):
ocp.load_pytree(self.directory)

Expand Down
Loading
Loading