4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -12,13 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to
include an arbitrary `step_prefix` with any character(s) such as underscores.
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
- Fix using `jax.eval_shape` with `StandardRestore`.

### Changed

- Validate checkpoints before writing merged OCDBT database using in-memory
state, avoiding additional I/O to re-read metadata.
- Add `support_format` to `utils.to_shape_dtype_struct()`.
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
- Replace usage of `get_json_tspec_read` and delegate functionality to new
  function `build_array_read_spec`, which constructs and returns an
  `ArrayReadSpec`; a minimal migration sketch follows below.
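
A minimal sketch of the migration, based on the call sites shown in this PR's diff; `info` (a `types.ParamInfo`), the surrounding keyword arguments, and the module path follow this repository's internal layout rather than a stable public API:

```python
from orbax.checkpoint._src.serialization import ts_utils

# Before: the removed helper returned a raw JSON tspec directly.
# tspec = ts_utils.get_json_tspec_read(
#     info,
#     use_ocdbt=use_ocdbt,
#     metadata_key=metadata_key,
#     raise_array_data_missing_error=info.raise_array_data_missing_error,
# )

# After: build an ArrayReadSpec, then take its JSON form.
array_read_spec = ts_utils.build_array_read_spec(
    info,
    use_ocdbt=use_ocdbt,
    metadata_key=metadata_key,
    raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
```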

## [0.11.28] - 2025-11-06

@@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler:
    return StandardCheckpointHandler()

  def test_with_random_keys(self):
+    # TODO(b/393160483) investigate Pathways remote Python support for
+    # random.keys.
    if utils.is_pathways_backend():
-      self.skipTest('Pathways does not support random keys checkpoint.')
+      self.skipTest(
+          'Disabled on Pathways because random keys are not supported by'
+          ' remote Python.'
+      )

    def create_random_keys(seed):
      duplicated_sharding = jax.sharding.NamedSharding(
@@ -559,3 +564,38 @@ def create_random_keys(seed):
        args=self.restore_args_cls(abstract_tree),
    )
    test_utils.assert_tree_equal(self, self.pytree, restored)

  def test_save_restore_random_keys_with_jax_eval_shape(self):
    # TODO(b/393160483) investigate Pathways remote Python support for
    # random.keys.
    if utils.is_pathways_backend():
      self.skipTest(
          'Disabled on Pathways because random keys are not supported by'
          ' remote Python.'
      )

    mesh = jax.sharding.Mesh(jax.devices(), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    @functools.partial(
        jax.jit,
        in_shardings=sharding,
        out_shardings=sharding,
    )
    def sharded_create_state_fn(root_key):
      return dict(
          matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
          rngkey=jax.random.fold_in(root_key, 42),
      )

    pytree = sharded_create_state_fn(jax.random.key(0))
    abstract_pytree = jax.eval_shape(
        sharded_create_state_fn, jax.random.key(0)
    )

    self.handler.save(self.directory, args=self.save_args_cls(pytree))

    restored = self.handler.restore(
        self.directory, args=self.restore_args_cls(abstract_pytree)
    )
    test_utils.assert_tree_equal(self, pytree, restored)
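
The new test exercises the changelog fix for `jax.eval_shape` with `StandardRestore`. A user-level sketch of the same pattern, assuming the public `orbax.checkpoint` API (`StandardCheckpointer`, `args.StandardRestore`) and an illustrative path:

```python
import jax
import orbax.checkpoint as ocp


def create_state(root_key):
  return {'rngkey': jax.random.fold_in(root_key, 42)}


state = create_state(jax.random.key(0))
# jax.eval_shape yields abstract leaves (including typed PRNG keys) without
# materializing any arrays; the result can serve as the restore target.
abstract_state = jax.eval_shape(create_state, jax.random.key(0))

ckptr = ocp.StandardCheckpointer()
ckptr.save('/tmp/ckpt', state)
ckptr.wait_until_finished()  # saves are async; block before reading back
restored = ckptr.restore(
    '/tmp/ckpt', args=ocp.args.StandardRestore(abstract_state)
)
```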
@@ -763,12 +763,13 @@ async def _async_deserialize(
  await _validate_non_ocdbt_files(infos, metadata_key)
  deserialize_ops = []
  for info, arg, sharding in zip(infos, args, shardings):
-    tspec = ts_utils.get_json_tspec_read(
+    array_read_spec = ts_utils.build_array_read_spec(
        info,
        use_ocdbt=use_ocdbt,
        metadata_key=metadata_key,
        raise_array_data_missing_error=info.raise_array_data_missing_error,
    )
+    tspec = array_read_spec.json
    tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)

    # set dtype=None to deserialize for random keys
@@ -939,19 +940,6 @@ def __init__(
  def has_dispatcher(self) -> bool:
    return self._dispatcher is not None

-  def _get_json_tspec_read(
-      self,
-      info: types.ParamInfo,
-      use_ocdbt: bool,
-  ) -> Dict[str, Any]:
-    """Gets Tensorstore spec for reading."""
-    return ts_utils.get_json_tspec_read(
-        info,
-        use_ocdbt=use_ocdbt,
-        metadata_key=self._metadata_key,
-        raise_array_data_missing_error=info.raise_array_data_missing_error,
-    )

  def typestr(self) -> str:
    return JAX_ARRAY_TYPE_STR
@@ -968,7 +956,13 @@ async def metadata(
    for info in infos:
      # Use OCDBT flag from the existing checkpoint.
      use_ocdbt = info.is_ocdbt_checkpoint
-      tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
+      array_read_spec = ts_utils.build_array_read_spec(
+          info,
+          use_ocdbt=use_ocdbt,
+          metadata_key=self._metadata_key,
+          raise_array_data_missing_error=info.raise_array_data_missing_error,
+      )
+      tspec = array_read_spec.json
      open_ops.append(
          ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
      )
@@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec(
  return array_tspec


class ArrayReadSpec:
  """Full TensorStore spec for reading an array."""

  def __init__(
      self,
      directory: str,
      relative_array_filename: str,
      use_zarr3: bool,
      *,
      use_ocdbt: bool,
      metadata_key: str | None = None,
      raise_array_data_missing_error: bool = True,
  ):
    """Builds a TensorStore spec for reading an array."""
    kvstore_tspec = build_kvstore_tspec(
        directory,
        name=relative_array_filename,
        use_ocdbt=use_ocdbt,
        process_id=None,
    )

    tspec = {
        'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2,
        'kvstore': kvstore_tspec,
        'recheck_cached_data': False,
        'recheck_cached_metadata': False,
        # Raise error if data is missing.
        'fill_missing_data_reads': not raise_array_data_missing_error,
    }
    if metadata_key is not None:
      tspec['metadata_key'] = metadata_key
    self._json_spec = tspec

  @property
  def json(self) -> JsonSpec:
    """Spec to be used to open a TensorStore for reading the array."""
    return self._json_spec


class ArrayWriteSpec:
  """Full TensorStore spec for writing an array."""

@@ -677,6 +716,26 @@ def get_json_tspec_write(
  return tspec


def build_array_read_spec(
    info: types.ParamInfo,
    *,
    use_ocdbt: bool,
    metadata_key: str | None = None,
    raise_array_data_missing_error: bool = True,
) -> ArrayReadSpec:
  """Gets ArrayReadSpec for reading."""
  if info.name is None or info.parent_dir is None:
    raise ValueError('Must provide info.name and info.parent_dir.')
  return ArrayReadSpec(
      directory=info.parent_dir.as_posix(),
      relative_array_filename=info.name,
      use_zarr3=info.use_zarr3,
      use_ocdbt=use_ocdbt,
      metadata_key=metadata_key,
      raise_array_data_missing_error=raise_array_data_missing_error,
  )


def build_array_write_spec(
    info: types.ParamInfo,
    arg: types.SaveArgs | None = None,
@@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self):
    self.assertTrue(ts_utils.is_remote_storage(nested_tspec))


class BuildArrayTSpecForReadTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.directory = self.create_tempdir().full_path
    self.param_name = 'params/a'

    self.array_read_spec_constructor = functools.partial(
        ts_utils.ArrayReadSpec,
        directory=self.directory,
        relative_array_filename=self.param_name,
    )

  @parameterized.product(
      use_zarr3=(True, False),
      use_ocdbt=(True, False),
  )
  def test_basic(self, use_zarr3: bool, use_ocdbt: bool):
    tspec = self.array_read_spec_constructor(
        use_zarr3=use_zarr3,
        use_ocdbt=use_ocdbt,
    )
    json_spec = tspec.json
    self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr')
    self.assertEqual(
        json_spec['kvstore']['driver'],
        'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER,
    )
    self.assertFalse(json_spec['recheck_cached_data'])
    self.assertFalse(json_spec['recheck_cached_metadata'])
    self.assertFalse(json_spec['fill_missing_data_reads'])
    self.assertNotIn('metadata_key', json_spec)

  def test_metadata_key(self):
    tspec = self.array_read_spec_constructor(
        use_zarr3=False,
        use_ocdbt=False,
        metadata_key='custom_metadata',
    )
    self.assertEqual(tspec.json['metadata_key'], 'custom_metadata')

  @parameterized.parameters(True, False)
  def test_fill_missing_data_reads(self, raise_array_data_missing_error):
    tspec = self.array_read_spec_constructor(
        use_zarr3=False,
        use_ocdbt=False,
        raise_array_data_missing_error=raise_array_data_missing_error,
    )
    self.assertEqual(
        tspec.json['fill_missing_data_reads'],
        not raise_array_data_missing_error,
    )


class GetTsContextTest(parameterized.TestCase):

  @parameterized.product(
60 changes: 21 additions & 39 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
@@ -77,39 +77,6 @@ def __init__(
    self._metadata_key = metadata_key
    self._override_ocdbt_process_id = ocdbt_process_id

-  def _get_array_write_spec(
-      self,
-      info: types.ParamInfo,
-      value: np.ndarray,
-      use_ocdbt: bool,
-      process_index: Optional[Union[int, str]] = None,
-      arg: Optional[types.SaveArgs] = None,
-  ) -> ts_utils.ArrayWriteSpec:
-    """Gets ArrayWriteSpec for writing."""
-    return ts_utils.build_array_write_spec(
-        info=info,
-        arg=arg,
-        global_shape=value.shape,
-        local_shape=value.shape,
-        dtype=value.dtype,
-        use_ocdbt=use_ocdbt,
-        process_index=process_index,
-        metadata_key=self._metadata_key,
-    )
-
-  def _get_json_tspec_read(
-      self,
-      info: types.ParamInfo,
-      use_ocdbt: bool,
-  ) -> Dict[str, Any]:
-    """Gets Tensorstore spec for reading."""
-    return ts_utils.get_json_tspec_read(
-        info,
-        use_ocdbt=use_ocdbt,
-        metadata_key=self._metadata_key,
-        raise_array_data_missing_error=info.raise_array_data_missing_error,
-    )
  def typestr(self) -> str:
    return 'np.ndarray'

@@ -120,7 +87,13 @@ async def metadata(
    for info in infos:
      # Use OCDBT flag from the existing checkpoint.
      use_ocdbt = info.is_ocdbt_checkpoint
-      tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
+      array_read_spec = ts_utils.build_array_read_spec(
+          info,
+          use_ocdbt=use_ocdbt,
+          metadata_key=self._metadata_key,
+          raise_array_data_missing_error=info.raise_array_data_missing_error,
+      )
+      tspec = array_read_spec.json
      open_ops.append(
          ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
      )
@@ -149,15 +122,18 @@ async def _background_serialize(
"""Serializes numpy arrays in a background thread."""
write_coros = []
for value, info, arg in zip(values, infos, args):
array_write_spec = self._get_array_write_spec(
info,
value,
array_write_spec = ts_utils.build_array_write_spec(
info=info,
arg=arg,
global_shape=value.shape,
local_shape=value.shape,
dtype=value.dtype,
use_ocdbt=info.is_ocdbt_checkpoint,
process_index=ocdbt_utils.get_process_index_for_subdir(
use_ocdbt=info.is_ocdbt_checkpoint,
override_ocdbt_process_id=self._override_ocdbt_process_id,
),
arg=arg,
metadata_key=self._metadata_key,
)
tspec = array_write_spec.json
if logging.vlog_is_on(1):
@@ -205,7 +181,13 @@ async def deserialize(
      )
      # Use OCDBT flag from the existing checkpoint.
      use_ocdbt = info.is_ocdbt_checkpoint
-      tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
+      array_read_spec = ts_utils.build_array_read_spec(
+          info,
+          use_ocdbt=use_ocdbt,
+          metadata_key=self._metadata_key,
+          raise_array_data_missing_error=info.raise_array_data_missing_error,
+      )
+      tspec = array_read_spec.json
      tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)

      if logging.vlog_is_on(1):
@@ -118,19 +118,13 @@ def test_fn(
      with metrics.measure(f'train_step_{i}'):
        pytree = self._train_step(pytree)

    save_times = np.array(save_times)
    total_save_times = np.array(total_save_times)

    # Exclude step 0 from assertions; setup may take extra time.
    asserting_save_times = save_times[1:]
    asserting_total_save_times = total_save_times[1:]

    mean_save_time = np.mean(asserting_save_times)
    mean_total_save_time = np.mean(asserting_total_save_times)

    assert np.all(asserting_save_times <= 2 * mean_save_time), (
        f'Save times={asserting_save_times}, mean save time={mean_save_time}'
    )
    assert np.all(asserting_total_save_times <= 2 * mean_total_save_time), (
        f'Total save times={asserting_total_save_times}, mean total save'
        f' time={mean_total_save_time}'
8 changes: 7 additions & 1 deletion checkpoint/orbax/checkpoint/checkpoint_utils.py
@@ -506,10 +506,16 @@ def _array_restore_args(
    sharding: Optional[jax.sharding.Sharding | Format],  # pytype: disable=unsupported-operands
    dtype: Optional[np.dtype] = None,
) -> type_handlers.ArrayRestoreArgs:
+  global_shape = None
+  # For random keys, we only allow overriding the sharding.
+  if set_global_shape and not jax.dtypes.issubdtype(
+      value.dtype, jax.dtypes.prng_key
+  ):
+    global_shape = value.shape
  return type_handlers.ArrayRestoreArgs(
      restore_type=jax.Array,
      sharding=sharding,
-      global_shape=value.shape if set_global_shape else None,
+      global_shape=global_shape,
      dtype=dtype,
  )
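
Why the shape override is skipped for typed PRNG keys: a key array's user-visible shape differs from the shape of its underlying data, so a caller-supplied `global_shape` would not describe the stored representation. A small sketch, assuming JAX's default threefry keys:

```python
import jax

key = jax.random.key(0)  # typed PRNG key array
assert key.shape == ()  # user-visible shape is scalar
assert jax.random.key_data(key).shape == (2,)  # two uint32 words underneath
assert jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
```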

@@ -101,7 +101,8 @@ async def open_tensorstore(
      use_zarr3=use_zarr3,
      ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt),
  )
-  tspec = ts_utils.get_json_tspec_read(info, use_ocdbt=use_ocdbt)
+  array_read_spec = ts_utils.build_array_read_spec(info, use_ocdbt=use_ocdbt)
+  tspec = array_read_spec.json
  return await ts.open(
      ts.Spec(tspec),
      read=True,