Skip to content

Commit 201e692

Browse files
author
Orbax Authors
committed
Run all correctness benchmarks in Github
PiperOrigin-RevId: 831318632
1 parent d966ddf commit 201e692

File tree

18 files changed

+301
-152
lines changed

18 files changed

+301
-152
lines changed

.github/workflows/build.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ jobs:
242242
working-directory: checkpoint
243243
strategy:
244244
matrix:
245-
python-version: ["3.10", "3.11", "3.12"]
245+
python-version: ["3.10"]
246246
jax-version: ["0.6.0"]
247247
steps:
248248
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -267,8 +267,11 @@ jobs:
267267
- name: Run benchmarks
268268
env:
269269
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
270+
TF_FORCE_GPU_ALLOW_GROWTH: true
271+
XLA_PYTHON_CLIENT_PREALLOCATE: false
272+
KERAS_BACKEND: "jax"
270273
run: |
271-
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
274+
cd orbax/checkpoint/_src/testing/benchmarks && python -c "import jax; print(jax.devices()); import run_benchmarks; run_benchmarks.main(['run_benchmarks.py', '--config_file=configs/pytree_checkpoint_benchmark.yaml', '--output_directory=$GCS_BUCKET_PATH'])"
272275
cd ../../../../..
273276
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
274277
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH

.github/workflows/multiprocess_tests.yml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
working-directory: checkpoint
2626
strategy:
2727
matrix:
28-
python-version: ["3.10", "3.11", "3.12"]
28+
python-version: ["3.12"]
2929
jax-version: ["0.6.0"]
3030
steps:
3131
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -50,9 +50,27 @@ jobs:
5050
- name: Run benchmarks
5151
env:
5252
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
53+
TF_FORCE_GPU_ALLOW_GROWTH: true
54+
XLA_PYTHON_CLIENT_PREALLOCATE: false
5355
run: |
54-
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
56+
cd orbax/checkpoint/_src/testing/benchmarks
57+
failed_benchmarks=""
58+
benchmark_configs_file="multiprocess_benchmark_configs.txt"
59+
echo "Running benchmarks specified in $benchmark_configs_file"
60+
while IFS= read -r entry || [ -n "$entry" ]; do
61+
if [ -n "$entry" ]; then
62+
echo "Running benchmark for $entry"
63+
if ! python run_benchmarks.py --config_file="$entry" --output_directory=$GCS_BUCKET_PATH; then
64+
echo "Benchmark $entry failed"
65+
failed_benchmarks="$failed_benchmarks $entry"
66+
fi
67+
fi
68+
done < "$benchmark_configs_file"
5569
cd ../../../../..
70+
if [ -n "$failed_benchmarks" ]; then
71+
echo "The following benchmarks failed:$failed_benchmarks"
72+
exit 1
73+
fi
5674
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
5775
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
5876
# The below step just reports the success or failure of tests as a "commit status".

checkpoint/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919
state, avoiding additional I/O to re-read metadata.
2020
- add `support_format` to utils.to_shape_dtype_struct()
2121
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
22+
- Replace usage of `get_json_tpec_read` and delegate functionality to new
23+
function `build_array_read_spec` which constructs and returns an
24+
`ArrayReadSpec`.
2225

2326
## [0.11.28] - 2025-11-06
2427

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -763,12 +763,13 @@ async def _async_deserialize(
763763
await _validate_non_ocdbt_files(infos, metadata_key)
764764
deserialize_ops = []
765765
for info, arg, sharding in zip(infos, args, shardings):
766-
tspec = ts_utils.get_json_tspec_read(
766+
array_read_spec = ts_utils.build_array_read_spec(
767767
info,
768768
use_ocdbt=use_ocdbt,
769769
metadata_key=metadata_key,
770770
raise_array_data_missing_error=info.raise_array_data_missing_error,
771771
)
772+
tspec = array_read_spec.json
772773
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)
773774

774775
# set dtype=None to deserialize for random keys
@@ -939,19 +940,6 @@ def __init__(
939940
def has_dispatcher(self) -> bool:
940941
return self._dispatcher is not None
941942

942-
def _get_json_tspec_read(
943-
self,
944-
info: types.ParamInfo,
945-
use_ocdbt: bool,
946-
) -> Dict[str, Any]:
947-
"""Gets Tensorstore spec for reading."""
948-
return ts_utils.get_json_tspec_read(
949-
info,
950-
use_ocdbt=use_ocdbt,
951-
metadata_key=self._metadata_key,
952-
raise_array_data_missing_error=info.raise_array_data_missing_error,
953-
)
954-
955943
def typestr(self) -> str:
956944
return JAX_ARRAY_TYPE_STR
957945

@@ -968,7 +956,13 @@ async def metadata(
968956
for info in infos:
969957
# Use OCDBT flag from the existing checkpoint.
970958
use_ocdbt = info.is_ocdbt_checkpoint
971-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
959+
array_read_spec = ts_utils.build_array_read_spec(
960+
info,
961+
use_ocdbt=use_ocdbt,
962+
metadata_key=self._metadata_key,
963+
raise_array_data_missing_error=info.raise_array_data_missing_error,
964+
)
965+
tspec = array_read_spec.json
972966
open_ops.append(
973967
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
974968
)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec(
391391
return array_tspec
392392

393393

394+
class ArrayReadSpec:
395+
"""Full TensorStore spec for reading an array."""
396+
397+
def __init__(
398+
self,
399+
directory: str,
400+
relative_array_filename: str,
401+
use_zarr3: bool,
402+
*,
403+
use_ocdbt: bool,
404+
metadata_key: str | None = None,
405+
raise_array_data_missing_error: bool = True,
406+
):
407+
"""Builds a TensorStore spec for reading an array."""
408+
kvstore_tspec = build_kvstore_tspec(
409+
directory,
410+
name=relative_array_filename,
411+
use_ocdbt=use_ocdbt,
412+
process_id=None,
413+
)
414+
415+
tspec = {
416+
'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2,
417+
'kvstore': kvstore_tspec,
418+
'recheck_cached_data': False,
419+
'recheck_cached_metadata': False,
420+
# Raise error if data is missing.
421+
'fill_missing_data_reads': not raise_array_data_missing_error,
422+
}
423+
if metadata_key is not None:
424+
tspec['metadata_key'] = metadata_key
425+
self._json_spec = tspec
426+
427+
@property
428+
def json(self) -> JsonSpec:
429+
"""Spec to be used to open a TensorStore for reading the array."""
430+
return self._json_spec
431+
432+
394433
class ArrayWriteSpec:
395434
"""Full TensorStore spec for writing an array."""
396435

@@ -677,6 +716,26 @@ def get_json_tspec_write(
677716
return tspec
678717

679718

719+
def build_array_read_spec(
720+
info: types.ParamInfo,
721+
*,
722+
use_ocdbt: bool,
723+
metadata_key: str | None = None,
724+
raise_array_data_missing_error: bool = True,
725+
) -> ArrayReadSpec:
726+
"""Gets ArrayReadSpec for reading."""
727+
if info.name is None or info.parent_dir is None:
728+
raise ValueError('Must provide info.name and info.parent_dir.')
729+
return ArrayReadSpec(
730+
directory=info.parent_dir.as_posix(),
731+
relative_array_filename=info.name,
732+
use_zarr3=info.use_zarr3,
733+
use_ocdbt=use_ocdbt,
734+
metadata_key=metadata_key,
735+
raise_array_data_missing_error=raise_array_data_missing_error,
736+
)
737+
738+
680739
def build_array_write_spec(
681740
info: types.ParamInfo,
682741
arg: types.SaveArgs | None = None,

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self):
613613
self.assertTrue(ts_utils.is_remote_storage(nested_tspec))
614614

615615

616+
class BuildArrayTSpecForReadTest(parameterized.TestCase):
617+
618+
def setUp(self):
619+
super().setUp()
620+
self.directory = self.create_tempdir().full_path
621+
self.param_name = 'params/a'
622+
623+
self.array_read_spec_constructor = functools.partial(
624+
ts_utils.ArrayReadSpec,
625+
directory=self.directory,
626+
relative_array_filename=self.param_name,
627+
)
628+
629+
@parameterized.product(
630+
use_zarr3=(True, False),
631+
use_ocdbt=(True, False),
632+
)
633+
def test_basic(self, use_zarr3: bool, use_ocdbt: bool):
634+
tspec = self.array_read_spec_constructor(
635+
use_zarr3=use_zarr3,
636+
use_ocdbt=use_ocdbt,
637+
)
638+
json_spec = tspec.json
639+
self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr')
640+
self.assertEqual(
641+
json_spec['kvstore']['driver'],
642+
'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER,
643+
)
644+
self.assertFalse(json_spec['recheck_cached_data'])
645+
self.assertFalse(json_spec['recheck_cached_metadata'])
646+
self.assertFalse(json_spec['fill_missing_data_reads'])
647+
self.assertNotIn('metadata_key', json_spec)
648+
649+
def test_metadata_key(self):
650+
tspec = self.array_read_spec_constructor(
651+
use_zarr3=False,
652+
use_ocdbt=False,
653+
metadata_key='custom_metadata',
654+
)
655+
self.assertEqual(tspec.json['metadata_key'], 'custom_metadata')
656+
657+
@parameterized.parameters(True, False)
658+
def test_fill_missing_data_reads(self, raise_array_data_missing_error):
659+
tspec = self.array_read_spec_constructor(
660+
use_zarr3=False,
661+
use_ocdbt=False,
662+
raise_array_data_missing_error=raise_array_data_missing_error,
663+
)
664+
self.assertEqual(
665+
tspec.json['fill_missing_data_reads'],
666+
not raise_array_data_missing_error,
667+
)
668+
669+
616670
class GetTsContextTest(parameterized.TestCase):
617671

618672
@parameterized.product(

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,39 +77,6 @@ def __init__(
7777
self._metadata_key = metadata_key
7878
self._override_ocdbt_process_id = ocdbt_process_id
7979

80-
def _get_array_write_spec(
81-
self,
82-
info: types.ParamInfo,
83-
value: np.ndarray,
84-
use_ocdbt: bool,
85-
process_index: Optional[Union[int, str]] = None,
86-
arg: Optional[types.SaveArgs] = None,
87-
) -> ts_utils.ArrayWriteSpec:
88-
"""Gets ArrayWriteSpec for writing."""
89-
return ts_utils.build_array_write_spec(
90-
info=info,
91-
arg=arg,
92-
global_shape=value.shape,
93-
local_shape=value.shape,
94-
dtype=value.dtype,
95-
use_ocdbt=use_ocdbt,
96-
process_index=process_index,
97-
metadata_key=self._metadata_key,
98-
)
99-
100-
def _get_json_tspec_read(
101-
self,
102-
info: types.ParamInfo,
103-
use_ocdbt: bool,
104-
) -> Dict[str, Any]:
105-
"""Gets Tensorstore spec for reading."""
106-
return ts_utils.get_json_tspec_read(
107-
info,
108-
use_ocdbt=use_ocdbt,
109-
metadata_key=self._metadata_key,
110-
raise_array_data_missing_error=info.raise_array_data_missing_error,
111-
)
112-
11380
def typestr(self) -> str:
11481
return 'np.ndarray'
11582

@@ -120,7 +87,13 @@ async def metadata(
12087
for info in infos:
12188
# Use OCDBT flag from the existing checkpoint.
12289
use_ocdbt = info.is_ocdbt_checkpoint
123-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
90+
array_read_spec = ts_utils.build_array_read_spec(
91+
info,
92+
use_ocdbt=use_ocdbt,
93+
metadata_key=self._metadata_key,
94+
raise_array_data_missing_error=info.raise_array_data_missing_error,
95+
)
96+
tspec = array_read_spec.json
12497
open_ops.append(
12598
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
12699
)
@@ -149,15 +122,18 @@ async def _background_serialize(
149122
"""Serializes numpy arrays in a background thread."""
150123
write_coros = []
151124
for value, info, arg in zip(values, infos, args):
152-
array_write_spec = self._get_array_write_spec(
153-
info,
154-
value,
125+
array_write_spec = ts_utils.build_array_write_spec(
126+
info=info,
127+
arg=arg,
128+
global_shape=value.shape,
129+
local_shape=value.shape,
130+
dtype=value.dtype,
155131
use_ocdbt=info.is_ocdbt_checkpoint,
156132
process_index=ocdbt_utils.get_process_index_for_subdir(
157133
use_ocdbt=info.is_ocdbt_checkpoint,
158134
override_ocdbt_process_id=self._override_ocdbt_process_id,
159135
),
160-
arg=arg,
136+
metadata_key=self._metadata_key,
161137
)
162138
tspec = array_write_spec.json
163139
if logging.vlog_is_on(1):
@@ -205,7 +181,13 @@ async def deserialize(
205181
)
206182
# Use OCDBT flag from the existing checkpoint.
207183
use_ocdbt = info.is_ocdbt_checkpoint
208-
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
184+
array_read_spec = ts_utils.build_array_read_spec(
185+
info,
186+
use_ocdbt=use_ocdbt,
187+
metadata_key=self._metadata_key,
188+
raise_array_data_missing_error=info.raise_array_data_missing_error,
189+
)
190+
tspec = array_read_spec.json
209191
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)
210192

211193
if logging.vlog_is_on(1):

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/array_handler_benchmark.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
suite_name: "ArrayHandler Benchmark"
22

3-
mesh_config:
4-
mesh_axes: ["data", "model"]
5-
ici_parallelism: {"data": 2, "model": 2}
6-
dcn_parallelism: {"data": 2, "model": 1}
3+
mesh_configs:
4+
- mesh_axes: ["data", "model"]
5+
ici_parallelism: {"data": 2, "model": 2}
6+
dcn_parallelism: {"data": 2, "model": 1}
7+
- mesh_axes: ["data", "model"]
8+
ici_parallelism: {"data": 1, "model": 1}
9+
dcn_parallelism: {"data": 4, "model": 1}
710

811
checkpoint_config:
912
spec:
@@ -17,5 +20,4 @@ benchmarks:
1720
use_replica_parallel: [True, False]
1821
enable_replica_parallel_separate_folder: [True, False]
1922
use_metadata_store: [True, False]
20-
use_colocated_python: [True, False]
2123

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/checkpoint_manager_benchmark.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
suite_name: "CheckpointManager Benchmark"
22

3-
mesh_config:
4-
mesh_axes: ["data", "model"]
5-
ici_parallelism: {"data": 2, "model": 2}
3+
mesh_configs:
4+
- mesh_axes: ["data", "model"]
5+
ici_parallelism: {"data": 2, "model": 2}
6+
- mesh_axes: ["data", "model"]
7+
ici_parallelism: {"data": 1, "model": 1}
8+
dcn_parallelism: {"data": 4, "model": 1}
69

710
checkpoint_config:
811
spec:

0 commit comments

Comments
 (0)