Skip to content

Commit 2f18992

Browse files
hejiang0116 authored and Orbax Authors committed
Internal change
PiperOrigin-RevId: 834940834
1 parent d966ddf commit 2f18992

22 files changed

+427
-167
lines changed

checkpoint/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to
1313
include an arbitrary `step_prefix` with any character(s) such as underscores.
1414
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
15+
- Fix using `jax.eval_shape` with `StandardRestore`.
1516

1617
### Changed
1718

1819
- Validate checkpoints before writing merged OCDBT database using in-memory
1920
state, avoiding additional I/O to re-read metadata.
2021
- Add `support_format` to `utils.to_shape_dtype_struct()`.
2122
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
23+
- Replace usage of `get_json_tspec_read` and delegate functionality to new
24+
function `build_array_read_spec` which constructs and returns an
25+
`ArrayReadSpec`.
2226

2327
## [0.11.28] - 2025-11-06
2428

checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler:
505505
return StandardCheckpointHandler()
506506

507507
def test_with_random_keys(self):
508+
# TODO(b/393160483) investigate Pathways remote Python support for
509+
# random.keys.
508510
if utils.is_pathways_backend():
509-
self.skipTest('Pathways does not support random keys checkpoint.')
511+
self.skipTest(
512+
'Disabled on Pathways because random keys are not supported by'
513+
' remote Python.'
514+
)
510515

511516
def create_random_keys(seed):
512517
duplicated_sharding = jax.sharding.NamedSharding(
@@ -559,3 +564,38 @@ def create_random_keys(seed):
559564
args=self.restore_args_cls(abstract_tree),
560565
)
561566
test_utils.assert_tree_equal(self, self.pytree, restored)
567+
568+
def test_save_restore_random_keys_with_jax_eval_shape(self):
569+
# TODO(b/393160483) investigate Pathways remote Python support for
570+
# random.keys.
571+
if utils.is_pathways_backend():
572+
self.skipTest(
573+
'Disabled on Pathways because random keys are not supported by'
574+
' remote Python.'
575+
)
576+
577+
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
578+
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
579+
580+
@functools.partial(
581+
jax.jit,
582+
in_shardings=sharding,
583+
out_shardings=sharding,
584+
)
585+
def sharded_create_state_fn(root_key):
586+
return dict(
587+
matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
588+
rngkey=jax.random.fold_in(root_key, 42),
589+
)
590+
591+
pytree = sharded_create_state_fn(jax.random.key(0))
592+
abstract_pytree = jax.eval_shape(
593+
sharded_create_state_fn, jax.random.key(0)
594+
)
595+
596+
self.handler.save(self.directory, args=self.save_args_cls(pytree))
597+
598+
restored = self.handler.restore(
599+
self.directory, args=self.restore_args_cls(abstract_pytree)
600+
)
601+
test_utils.assert_tree_equal(self, pytree, restored)

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -763,12 +763,13 @@ async def _async_deserialize(
763763
await _validate_non_ocdbt_files(infos, metadata_key)
764764
deserialize_ops = []
765765
for info, arg, sharding in zip(infos, args, shardings):
766-
tspec = ts_utils.get_json_tspec_read(
766+
array_read_spec = ts_utils.build_array_read_spec(
767767
info,
768768
use_ocdbt=use_ocdbt,
769769
metadata_key=metadata_key,
770770
raise_array_data_missing_error=info.raise_array_data_missing_error,
771771
)
772+
tspec = array_read_spec.json
772773
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)
773774

774775
# set dtype=None to deserialize for random keys
@@ -939,19 +940,6 @@ def __init__(
939940
def has_dispatcher(self) -> bool:
940941
return self._dispatcher is not None
941942

942-
def _get_json_tspec_read(
943-
self,
944-
info: types.ParamInfo,
945-
use_ocdbt: bool,
946-
) -> Dict[str, Any]:
947-
"""Gets Tensorstore spec for reading."""
948-
return ts_utils.get_json_tspec_read(
949-
info,
950-
use_ocdbt=use_ocdbt,
951-
metadata_key=self._metadata_key,
952-
raise_array_data_missing_error=info.raise_array_data_missing_error,
953-
)
954-
955943
def typestr(self) -> str:
956944
return JAX_ARRAY_TYPE_STR
957945

@@ -968,7 +956,13 @@ async def metadata(
968956
for info in infos:
969957
# Use OCDBT flag from the existing checkpoint.
970958
use_ocdbt = info.is_ocdbt_checkpoint
971-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
959+
array_read_spec = ts_utils.build_array_read_spec(
960+
info,
961+
use_ocdbt=use_ocdbt,
962+
metadata_key=self._metadata_key,
963+
raise_array_data_missing_error=info.raise_array_data_missing_error,
964+
)
965+
tspec = array_read_spec.json
972966
open_ops.append(
973967
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
974968
)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec(
391391
return array_tspec
392392

393393

394+
class ArrayReadSpec:
395+
"""Full TensorStore spec for reading an array."""
396+
397+
def __init__(
398+
self,
399+
directory: str,
400+
relative_array_filename: str,
401+
use_zarr3: bool,
402+
*,
403+
use_ocdbt: bool,
404+
metadata_key: str | None = None,
405+
raise_array_data_missing_error: bool = True,
406+
):
407+
"""Builds a TensorStore spec for reading an array."""
408+
kvstore_tspec = build_kvstore_tspec(
409+
directory,
410+
name=relative_array_filename,
411+
use_ocdbt=use_ocdbt,
412+
process_id=None,
413+
)
414+
415+
tspec = {
416+
'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2,
417+
'kvstore': kvstore_tspec,
418+
'recheck_cached_data': False,
419+
'recheck_cached_metadata': False,
420+
# Raise error if data is missing.
421+
'fill_missing_data_reads': not raise_array_data_missing_error,
422+
}
423+
if metadata_key is not None:
424+
tspec['metadata_key'] = metadata_key
425+
self._json_spec = tspec
426+
427+
@property
428+
def json(self) -> JsonSpec:
429+
"""Spec to be used to open a TensorStore for reading the array."""
430+
return self._json_spec
431+
432+
394433
class ArrayWriteSpec:
395434
"""Full TensorStore spec for writing an array."""
396435

@@ -677,6 +716,26 @@ def get_json_tspec_write(
677716
return tspec
678717

679718

719+
def build_array_read_spec(
720+
info: types.ParamInfo,
721+
*,
722+
use_ocdbt: bool,
723+
metadata_key: str | None = None,
724+
raise_array_data_missing_error: bool = True,
725+
) -> ArrayReadSpec:
726+
"""Gets ArrayReadSpec for reading."""
727+
if info.name is None or info.parent_dir is None:
728+
raise ValueError('Must provide info.name and info.parent_dir.')
729+
return ArrayReadSpec(
730+
directory=info.parent_dir.as_posix(),
731+
relative_array_filename=info.name,
732+
use_zarr3=info.use_zarr3,
733+
use_ocdbt=use_ocdbt,
734+
metadata_key=metadata_key,
735+
raise_array_data_missing_error=raise_array_data_missing_error,
736+
)
737+
738+
680739
def build_array_write_spec(
681740
info: types.ParamInfo,
682741
arg: types.SaveArgs | None = None,

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self):
613613
self.assertTrue(ts_utils.is_remote_storage(nested_tspec))
614614

615615

616+
class BuildArrayTSpecForReadTest(parameterized.TestCase):
617+
618+
def setUp(self):
619+
super().setUp()
620+
self.directory = self.create_tempdir().full_path
621+
self.param_name = 'params/a'
622+
623+
self.array_read_spec_constructor = functools.partial(
624+
ts_utils.ArrayReadSpec,
625+
directory=self.directory,
626+
relative_array_filename=self.param_name,
627+
)
628+
629+
@parameterized.product(
630+
use_zarr3=(True, False),
631+
use_ocdbt=(True, False),
632+
)
633+
def test_basic(self, use_zarr3: bool, use_ocdbt: bool):
634+
tspec = self.array_read_spec_constructor(
635+
use_zarr3=use_zarr3,
636+
use_ocdbt=use_ocdbt,
637+
)
638+
json_spec = tspec.json
639+
self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr')
640+
self.assertEqual(
641+
json_spec['kvstore']['driver'],
642+
'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER,
643+
)
644+
self.assertFalse(json_spec['recheck_cached_data'])
645+
self.assertFalse(json_spec['recheck_cached_metadata'])
646+
self.assertFalse(json_spec['fill_missing_data_reads'])
647+
self.assertNotIn('metadata_key', json_spec)
648+
649+
def test_metadata_key(self):
650+
tspec = self.array_read_spec_constructor(
651+
use_zarr3=False,
652+
use_ocdbt=False,
653+
metadata_key='custom_metadata',
654+
)
655+
self.assertEqual(tspec.json['metadata_key'], 'custom_metadata')
656+
657+
@parameterized.parameters(True, False)
658+
def test_fill_missing_data_reads(self, raise_array_data_missing_error):
659+
tspec = self.array_read_spec_constructor(
660+
use_zarr3=False,
661+
use_ocdbt=False,
662+
raise_array_data_missing_error=raise_array_data_missing_error,
663+
)
664+
self.assertEqual(
665+
tspec.json['fill_missing_data_reads'],
666+
not raise_array_data_missing_error,
667+
)
668+
669+
616670
class GetTsContextTest(parameterized.TestCase):
617671

618672
@parameterized.product(

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,39 +77,6 @@ def __init__(
7777
self._metadata_key = metadata_key
7878
self._override_ocdbt_process_id = ocdbt_process_id
7979

80-
def _get_array_write_spec(
81-
self,
82-
info: types.ParamInfo,
83-
value: np.ndarray,
84-
use_ocdbt: bool,
85-
process_index: Optional[Union[int, str]] = None,
86-
arg: Optional[types.SaveArgs] = None,
87-
) -> ts_utils.ArrayWriteSpec:
88-
"""Gets ArrayWriteSpec for writing."""
89-
return ts_utils.build_array_write_spec(
90-
info=info,
91-
arg=arg,
92-
global_shape=value.shape,
93-
local_shape=value.shape,
94-
dtype=value.dtype,
95-
use_ocdbt=use_ocdbt,
96-
process_index=process_index,
97-
metadata_key=self._metadata_key,
98-
)
99-
100-
def _get_json_tspec_read(
101-
self,
102-
info: types.ParamInfo,
103-
use_ocdbt: bool,
104-
) -> Dict[str, Any]:
105-
"""Gets Tensorstore spec for reading."""
106-
return ts_utils.get_json_tspec_read(
107-
info,
108-
use_ocdbt=use_ocdbt,
109-
metadata_key=self._metadata_key,
110-
raise_array_data_missing_error=info.raise_array_data_missing_error,
111-
)
112-
11380
def typestr(self) -> str:
11481
return 'np.ndarray'
11582

@@ -120,7 +87,13 @@ async def metadata(
12087
for info in infos:
12188
# Use OCDBT flag from the existing checkpoint.
12289
use_ocdbt = info.is_ocdbt_checkpoint
123-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
90+
array_read_spec = ts_utils.build_array_read_spec(
91+
info,
92+
use_ocdbt=use_ocdbt,
93+
metadata_key=self._metadata_key,
94+
raise_array_data_missing_error=info.raise_array_data_missing_error,
95+
)
96+
tspec = array_read_spec.json
12497
open_ops.append(
12598
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
12699
)
@@ -149,15 +122,18 @@ async def _background_serialize(
149122
"""Serializes numpy arrays in a background thread."""
150123
write_coros = []
151124
for value, info, arg in zip(values, infos, args):
152-
array_write_spec = self._get_array_write_spec(
153-
info,
154-
value,
125+
array_write_spec = ts_utils.build_array_write_spec(
126+
info=info,
127+
arg=arg,
128+
global_shape=value.shape,
129+
local_shape=value.shape,
130+
dtype=value.dtype,
155131
use_ocdbt=info.is_ocdbt_checkpoint,
156132
process_index=ocdbt_utils.get_process_index_for_subdir(
157133
use_ocdbt=info.is_ocdbt_checkpoint,
158134
override_ocdbt_process_id=self._override_ocdbt_process_id,
159135
),
160-
arg=arg,
136+
metadata_key=self._metadata_key,
161137
)
162138
tspec = array_write_spec.json
163139
if logging.vlog_is_on(1):
@@ -205,7 +181,13 @@ async def deserialize(
205181
)
206182
# Use OCDBT flag from the existing checkpoint.
207183
use_ocdbt = info.is_ocdbt_checkpoint
208-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
184+
array_read_spec = ts_utils.build_array_read_spec(
185+
info,
186+
use_ocdbt=use_ocdbt,
187+
metadata_key=self._metadata_key,
188+
raise_array_data_missing_error=info.raise_array_data_missing_error,
189+
)
190+
tspec = array_read_spec.json
209191
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)
210192

211193
if logging.vlog_is_on(1):

checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,13 @@ def test_fn(
118118
with metrics.measure(f'train_step_{i}'):
119119
pytree = self._train_step(pytree)
120120

121-
save_times = np.array(save_times)
122121
total_save_times = np.array(total_save_times)
123122

124123
# Exclude step 0 from assertions; setup may take extra time.
125-
asserting_save_times = save_times[1:]
126124
asserting_total_save_times = total_save_times[1:]
127125

128-
mean_save_time = np.mean(asserting_save_times)
129126
mean_total_save_time = np.mean(asserting_total_save_times)
130127

131-
assert np.all(asserting_save_times <= 2 * mean_save_time), (
132-
f'Save times={asserting_save_times}, mean save time={mean_save_time}'
133-
)
134128
assert np.all(asserting_total_save_times <= 2 * mean_total_save_time), (
135129
f'Total save times={asserting_total_save_times}, mean total save'
136130
f' time={mean_total_save_time}'

checkpoint/orbax/checkpoint/checkpoint_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,16 @@ def _array_restore_args(
506506
sharding: Optional[jax.sharding.Sharding | Format], # pytype: disable=unsupported-operands
507507
dtype: Optional[np.dtype] = None,
508508
) -> type_handlers.ArrayRestoreArgs:
509+
global_shape = None
510+
# For random keys, we only allow overriding the sharding.
511+
if set_global_shape and not jax.dtypes.issubdtype(
512+
value.dtype, jax.dtypes.prng_key
513+
):
514+
global_shape = value.shape
509515
return type_handlers.ArrayRestoreArgs(
510516
restore_type=jax.Array,
511517
sharding=sharding,
512-
global_shape=value.shape if set_global_shape else None,
518+
global_shape=global_shape,
513519
dtype=dtype,
514520
)
515521

checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ async def open_tensorstore(
101101
use_zarr3=use_zarr3,
102102
ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt),
103103
)
104-
tspec = ts_utils.get_json_tspec_read(info, use_ocdbt=use_ocdbt)
104+
array_read_spec = ts_utils.build_array_read_spec(info, use_ocdbt=use_ocdbt)
105+
tspec = array_read_spec.json
105106
return await ts.open(
106107
ts.Spec(tspec),
107108
read=True,

0 commit comments

Comments
 (0)