Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines `NumpyLayout` for loading NumPy checkpoint files."""

import asyncio
from typing import Any, Awaitable, IO
import zipfile

import jax
import jax.tree_util
import numpy as np
from numpy.lib import format as np_format
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


# Module-level shorthands for names defined in the imported layout and path
# modules; they are used by the definitions below.
CheckpointLayout = checkpoint_layout.CheckpointLayout
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
Path = types.Path


def _read_npy_header_3_0(fp: IO[bytes]) -> tuple[tuple[int, ...], np.dtype]:
  """Parses a version 3.0 npy header from `fp` (positioned after the magic).

  The version 3.0 header is a 4-byte little-endian unsigned length followed by
  a UTF-8 encoded Python dict literal with `descr`, `fortran_order` and `shape`
  entries.

  Args:
    fp: Binary file object positioned immediately after the magic/version bytes.

  Returns:
    A `(shape, dtype)` tuple describing the stored array.

  Raises:
    ValueError: If the header is truncated, too large, or malformed.
  """
  import ast  # pylint: disable=g-import-not-at-top
  import struct  # pylint: disable=g-import-not-at-top

  # Bound the text handed to `ast.literal_eval`; 10000 is intended to match the
  # default header cap NumPy applies in its own readers.
  max_header_size = 10000

  def read_exact(size: int, what: str) -> bytes:
    # File-like objects may return short reads, so loop until `size` bytes
    # have been collected or EOF is hit.
    chunks = []
    remaining = size
    while remaining > 0:
      chunk = fp.read(remaining)
      if not chunk:
        raise ValueError(f'EOF while reading npy {what}.')
      chunks.append(chunk)
      remaining -= len(chunk)
    return b''.join(chunks)

  (header_len,) = struct.unpack('<I', read_exact(4, 'header length'))
  if header_len > max_header_size:
    raise ValueError(
        f'npy header is too large ({header_len} bytes, limit'
        f' {max_header_size}).'
    )
  # UnicodeDecodeError is a ValueError subclass, so it propagates as-is.
  header = read_exact(header_len, 'header').decode('utf-8')
  try:
    fields = ast.literal_eval(header)
  except (SyntaxError, ValueError) as e:
    raise ValueError(f'Cannot parse npy header: {header!r}') from e
  if not isinstance(fields, dict) or not {'descr', 'shape'} <= fields.keys():
    raise ValueError(f'npy header is not a valid dictionary: {header!r}')

  shape = fields['shape']
  if not isinstance(shape, tuple) or not all(
      isinstance(dim, int) for dim in shape
  ):
    raise ValueError(f'npy header has an invalid shape: {shape!r}')
  try:
    dtype = np_format.descr_to_dtype(fields['descr'])
  except TypeError as e:
    raise ValueError(
        f'npy header has an invalid descr: {fields["descr"]!r}'
    ) from e
  return shape, dtype


def _get_npy_info(fp: IO[bytes]) -> tuple[tuple[int, ...], np.dtype]:
  """Reads shape and dtype from an npy file header without reading the data.

  Args:
    fp: Binary file object positioned at the start of an npy payload. It is
      advanced past the header.

  Returns:
    A `(shape, dtype)` tuple describing the stored array.

  Raises:
    ValueError: If the stream does not start with the npy magic, uses an
      unsupported format version, or has a malformed header.
  """
  try:
    version = np_format.read_magic(fp)
  except ValueError as e:
    raise ValueError('File does not start with npy magic') from e

  if version == (1, 0):
    shape, _, dtype = np_format.read_array_header_1_0(fp)
  elif version == (2, 0):
    shape, _, dtype = np_format.read_array_header_2_0(fp)
  elif version == (3, 0):
    # `numpy.lib.format` only exposes public header readers for 1.0 and 2.0.
    # Use a 3.0 reader if this NumPy happens to provide one, otherwise parse
    # the header directly instead of rejecting the file.
    reader = getattr(np_format, 'read_array_header_3_0', None)
    if reader is not None:
      shape, _, dtype = reader(fp)
    else:
      shape, dtype = _read_npy_header_3_0(fp)
  else:
    raise ValueError(f'Unsupported npy format version: {version}')
  return shape, dtype


def _load_numpy_on_device(
    npz_file: Any,
    abstract_pytree: tree_types.PyTreeOf[jax.ShapeDtypeStruct],
) -> tree_types.PyTreeOf[jax.Array]:
  """Loads arrays from `npz_file` into on-device JAX arrays.

  Args:
    npz_file: Mapping-like object supporting `key in npz_file` and
      `npz_file[key]`, where each value is a NumPy array.
    abstract_pytree: Flat dictionary mapping array names to abstract leaves
      that provide `shape`, `dtype` and, optionally, `sharding`.

  Returns:
    A dictionary mapping each requested key to a `jax.Array` with the requested
    shape and dtype. Leaves with a sharding are assembled from per-device
    shards; leaves without one are placed with a plain `jax.device_put`.

  Raises:
    ValueError: If `abstract_pytree` is not a flat dictionary, a key is missing
      from `npz_file`, or a stored array's shape differs from the requested
      shape.
  """
  restored_pytree = {}

  flat_abstract, _ = jax.tree_util.tree_flatten_with_path(abstract_pytree)
  for key_path, abstract_leaf in flat_abstract:
    if len(key_path) != 1 or not isinstance(key_path[0], jax.tree_util.DictKey):
      raise ValueError(
          f'The PyTree is not a flat dictionary. Key path: {key_path}'
      )
    key = str(key_path[0].key)
    if key not in npz_file:
      raise ValueError(f'Key {key} not found in npz file.')
    leaf = npz_file[key]  # loads numpy array
    target_shape = tuple(abstract_leaf.shape)
    target_dtype = abstract_leaf.dtype

    # Shards are cut out of `leaf` using indices computed from `target_shape`,
    # so a mismatching stored shape could otherwise be silently truncated.
    if tuple(leaf.shape) != target_shape:
      raise ValueError(
          f'Shape mismatch for key {key}: checkpoint has shape'
          f' {tuple(leaf.shape)}, but shape {target_shape} was requested.'
      )

    sharding = getattr(abstract_leaf, 'sharding', None)
    if sharding is None:
      # No placement requested: convert the dtype and let JAX choose the
      # device.
      restored_pytree[key] = jax.device_put(
          np.asarray(leaf, dtype=target_dtype)
      )
      continue

    device_indices_map = sharding.addressable_devices_indices_map(target_shape)
    device_arrays = []
    for device, idx in device_indices_map.items():
      shard_np = leaf[idx]
      if shard_np.dtype != target_dtype:
        shard_np = shard_np.astype(target_dtype)
      device_arrays.append(jax.device_put(shard_np, device))
    restored_pytree[key] = jax.make_array_from_single_device_arrays(
        target_shape, sharding, device_arrays
    )
  return restored_pytree


async def _load_numpy(
    path: Path,
    abstract_pytree: tree_types.PyTreeOf[jax.ShapeDtypeStruct] | None = None,
) -> dict[str, Any]:
  """Loads a NumPy checkpoint as NumPy arrays or sharded JAX arrays.

  Args:
    path: Location of the `.npz` file to read.
    abstract_pytree: When `None`, every array in the archive is returned as a
      NumPy array. Otherwise, the requested entries are loaded onto devices via
      `_load_numpy_on_device`.

  Returns:
    A dictionary with a single `checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY`
    entry holding the restored dictionary of arrays.
  """

  def _read_sync() -> Any:
    # SECURITY NOTE(review): `allow_pickle=True` lets NumPy unpickle object
    # arrays, which can execute arbitrary code. Only load checkpoints from
    # trusted sources, or confirm whether this flag is really required.
    with np.load(path, allow_pickle=True) as npz_file:
      if abstract_pytree is None:
        # Return NumPy arrays.
        return {k: npz_file[k] for k in npz_file.files}
      # Return on-device JAX arrays.
      return _load_numpy_on_device(npz_file, abstract_pytree)

  # `np.load` only opens the archive; each member is read from disk when it is
  # indexed. Run the whole read in a worker thread so the event loop is not
  # blocked by file I/O.
  restored_pytree = await asyncio.to_thread(_read_sync)

  return {checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: restored_pytree}


class NumpyLayout(CheckpointLayout):
  """Layout for loading NumPy checkpoints (.npz).

  The whole `.npz` file is exposed as a single checkpointable stored under
  `checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY`.
  """

  def __init__(self, path: Path):
    """Initializes the layout.

    Args:
      path: Location of the `.npz` checkpoint file.
    """
    self._path = path

  @property
  def path(self) -> Path:
    """Returns the path of the NumPy checkpoint file."""
    return self._path

  def _check_zip_structure(self) -> None:
    """Checks synchronously that the file is a ZIP archive holding .npy files.

    Raises:
      InvalidLayoutError: If the file cannot be read as a ZIP archive, is an
        empty archive, or contains no `.npy` members.
    """
    # Only the archive read is guarded by the broad handler; the structural
    # checks below raise `InvalidLayoutError` directly so their messages are
    # not re-wrapped as a generic "Failed to read" error.
    try:
      with zipfile.ZipFile(self._path, 'r') as zf:
        names = zf.namelist()
    except zipfile.BadZipFile as e:
      raise InvalidLayoutError(
          f"'{self._path}' is not a valid ZIP file."
      ) from e
    except Exception as e:
      raise InvalidLayoutError(
          f"Failed to read '{self._path}' as zip file: {e}"
      ) from e

    if not names:
      raise InvalidLayoutError(f"'{self._path}' is an empty zip archive.")
    if not any(name.endswith('.npy') for name in names):
      raise InvalidLayoutError(
          f"'{self._path}' is not a valid NumPy archive "
          '(missing .npy files).'
      )

  async def validate(self) -> None:
    """Checks if the path is a file and a valid NumPy ZIP archive.

    Raises:
      InvalidLayoutError: If the path is not a file, does not have a `.npz`
        suffix, or is not a ZIP archive containing `.npy` members.
    """
    if not await async_path.is_file(self._path):
      raise InvalidLayoutError(f'Path is not a file: {self._path}')
    if self._path.suffix != '.npz':
      raise InvalidLayoutError(
          f'File {self._path} must have a .npz suffix to be loaded as a'
          ' NumPy checkpoint.'
      )
    # `_check_zip_structure` converts every failure into `InvalidLayoutError`,
    # so no additional exception translation is needed here.
    await asyncio.to_thread(self._check_zip_structure)

  async def validate_pytree(self, checkpointable_name: str | None) -> None:
    """No-op, as NumpyLayout treats the entire file as the 'pytree' item."""
    return

  async def metadata(
      self,
  ) -> metadata_types.CheckpointMetadata[dict[str, tree_types.PyTreeOf[Any]]]:
    """Extracts ShapeDtypeStruct metadata without loading array data.

    Only the header of each `.npy` member is parsed.

    Returns:
      A `CheckpointMetadata` whose `metadata` maps
      `checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY` to a dictionary of
      `jax.ShapeDtypeStruct` keyed by array name, and whose
      `commit_timestamp_nsecs` is derived from the file's `mtime`.

    Raises:
      InvalidLayoutError: If the archive or any `.npy` header cannot be read.
    """

    def _read_metadata_sync() -> dict[str, jax.ShapeDtypeStruct]:
      metadata = {}
      try:
        with zipfile.ZipFile(self._path, 'r') as zf:
          for name in zf.namelist():
            if not name.endswith('.npy'):
              continue
            # Archive members are named '<array name>.npy'.
            arr_name = name.removesuffix('.npy')
            with zf.open(name) as f:
              shape, dtype = _get_npy_info(f)
            metadata[arr_name] = jax.ShapeDtypeStruct(
                shape=shape, dtype=dtype
            )
      except zipfile.BadZipFile as e:
        raise InvalidLayoutError(
            f"'{self._path}' is not a valid ZIP file."
        ) from e
      except Exception as e:
        raise InvalidLayoutError(
            f'Failed to read metadata from {self._path}'
        ) from e
      return metadata

    metadata_tree = await asyncio.to_thread(_read_metadata_sync)
    stat_result = await async_path.async_stat(self._path)
    # NOTE(review): assumes `mtime` is expressed in seconds -- TODO confirm.
    commit_timestamp_nsecs = int(stat_result.mtime * 1e9)

    return metadata_types.CheckpointMetadata[dict[str, Any]](
        metadata={checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: metadata_tree},
        commit_timestamp_nsecs=commit_timestamp_nsecs,
    )

  async def load(
      self,
      abstract_checkpointables: (
          dict[str, tree_types.PyTreeOf[jax.ShapeDtypeStruct]] | None
      ) = None,
  ) -> Awaitable[dict[str, tree_types.PyTreeOf[Any]]]:
    """Loads a NumPy checkpoint file.

    Args:
      abstract_checkpointables: Optional dictionary whose
        `checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY` entry describes the
        target arrays. When absent or empty, all arrays are loaded as NumPy
        arrays.

    Returns:
      An awaitable that performs the read when awaited and resolves to the
      restored checkpointables. Awaiting `load()` itself only creates that
      awaitable, so callers await twice.
    """
    abstract_pytree = None
    if abstract_checkpointables:
      abstract_pytree = abstract_checkpointables.get(
          checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
      )
    # Deliberately not awaited here: the coroutine is handed back to the
    # caller.
    return _load_numpy(self._path, abstract_pytree)
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import jax
import numpy as np
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import numpy_layout
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types


# Module-level shorthands for the layout class under test and the error type
# it raises.
NumpyLayout = numpy_layout.NumpyLayout
InvalidLayoutError = checkpoint_layout.InvalidLayoutError


class NumpyLayoutTest(unittest.IsolatedAsyncioTestCase, parameterized.TestCase):
  """Covers validation, loading and metadata extraction of `NumpyLayout`."""

  def setUp(self):
    super().setUp()
    self.test_dir = self.create_tempdir()
    root = epath.Path(self.test_dir.full_path)
    self.numpy_path = root / 'test_checkpoint.npz'

    # A small reference checkpoint shared by the validation/metadata tests.
    int_values = np.array(3 * [1, 2, 3], dtype=np.int32)
    float_values = np.array([0, 1, 0.2], dtype=np.float32)
    self.object_to_save = {'a': int_values, 'b': float_values}
    np.savez(self.numpy_path, **self.object_to_save)

  async def test_valid_numpy_checkpoint(self):
    # A well-formed .npz file must validate without raising.
    await NumpyLayout(self.numpy_path).validate()

  async def test_validate_fails_not_file(self):
    # A directory is not an acceptable checkpoint path.
    directory = epath.Path(self.test_dir.full_path)
    layout = NumpyLayout(directory)
    with self.assertRaises(InvalidLayoutError):
      await layout.validate()

  async def test_validate_fails_wrong_suffix(self):
    root = epath.Path(self.test_dir.full_path)
    wrong_suffix_path = root / 'test_checkpoint.txt'
    wrong_suffix_path.touch()
    layout = NumpyLayout(wrong_suffix_path)
    with self.assertRaises(InvalidLayoutError):
      await layout.validate()

  async def test_validate_fails_not_zip(self):
    root = epath.Path(self.test_dir.full_path)
    bad_zip_path = root / 'bad_zip.npz'
    bad_zip_path.write_text('this is not a zip file')
    layout = NumpyLayout(bad_zip_path)
    with self.assertRaises(InvalidLayoutError):
      await layout.validate()

  @parameterized.product(
      dtype=[
          np.int8,
          np.int32,
          np.int64,
          np.float16,
          np.float32,
          np.float64,
          np.bool_,
      ]
  )
  async def test_load_numpy_checkpoint(self, dtype: np.dtype):
    """Tests loading a NumPy checkpoint with various dtypes."""
    root = epath.Path(self.test_dir.full_path)
    test_path = root / f'test_{dtype.__name__}.npz'

    arr = (
        np.array([True, False, True, False])
        if dtype == np.bool_
        else np.arange(8, dtype=dtype)
    )
    obj_to_save = {'x': arr, 'y': np.array([1, 2, 3], dtype=np.int32)}
    np.savez(test_path, **obj_to_save)

    # `load()` yields an awaitable; awaiting that performs the actual read.
    pending_restore = await NumpyLayout(test_path).load()
    restored_checkpointables = await pending_restore
    pytree = restored_checkpointables['pytree']

    # Floating-point data is compared with a tolerance, everything else
    # exactly.
    is_floating = np.issubdtype(dtype, np.floating)
    compare_x = (
        np.testing.assert_allclose
        if is_floating
        else np.testing.assert_array_equal
    )
    compare_x(pytree['x'], obj_to_save['x'])
    np.testing.assert_array_equal(pytree['y'], obj_to_save['y'])

  async def test_metadata(self):
    metadata = await NumpyLayout(self.numpy_path).metadata()
    self.assertIsInstance(metadata, metadata_types.CheckpointMetadata)

    expected_tree = {
        'b': jax.ShapeDtypeStruct(shape=(3,), dtype=np.float32),
        'a': jax.ShapeDtypeStruct(shape=(9,), dtype=np.int32),
    }
    self.assertEqual(
        metadata.metadata,
        {checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: expected_tree},
    )

    timestamp = metadata.commit_timestamp_nsecs
    self.assertIsInstance(timestamp, int)
    self.assertGreater(timestamp, 0)


# Run the tests through absl's test runner when executed as a script.
if __name__ == '__main__':
  absltest.main()
Loading