Skip to content

Commit 1f570ef

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Add a check to prevent zero-sized arrays from being saved. This behavior already resulted in an error, but it was one that was difficult to parse.
PiperOrigin-RevId: 721870235
1 parent 6e80ecc commit 1f570ef

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

checkpoint/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
`AsyncCheckpointer.save()`, and `CheckpointManager.save()`, which saves a custom
1414
dict of user metadata to `StepMetadata`.
1515

16+
### Fixed
17+
18+
- Add a check to prevent zero-sized arrays from being saved. This behavior
19+
already resulted in an error, but it was one that was difficult to parse.
20+
1621
## [0.11.1] - 2025-01-28
1722

1823
### Changed

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,3 +2308,14 @@ def deserialize(
23082308
checkpoint_handler.save(
23092309
self.directory, args=PyTreeSaveArgs(self.pytree)
23102310
)
2311+
2312+
@parameterized.parameters((True,), (False,))
2313+
def test_zero_size_array(self, use_jax_array: bool):
2314+
arr = np.ones(shape=(0,))
2315+
mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('x',))
2316+
pspec = jax.sharding.PartitionSpec()
2317+
if use_jax_array:
2318+
arr = test_utils.create_sharded_array(arr, mesh, pspec)
2319+
tree = [arr]
2320+
with self.assertRaisesRegex(ValueError, 'zero size'):
2321+
self.handler.save(self.directory, args=PyTreeSaveArgs(tree))

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,12 @@ def check_input_arguments(*args):
249249
raise ValueError('Found input args with mismatched lengths.')
250250

251251

252+
def _check_array_values(values: Sequence[Union[jax.Array, np.ndarray]]):
253+
for v in values:
254+
if v.size == 0:
255+
raise ValueError('Cannot save arrays with zero size.')
256+
257+
252258
async def _validate_params(
253259
directory: epath.Path,
254260
ts_context: ts.Context,
@@ -631,6 +637,7 @@ async def serialize(
631637
"""Uses Tensorstore to serialize a numpy array."""
632638
args = args or [types.SaveArgs()] * len(values)
633639
check_input_arguments(values, infos, args)
640+
_check_array_values(values)
634641
if logging.vlog_is_on(1):
635642
_print_ts_debug_data(self._metadata_key, infos)
636643
copied_values = [copy.deepcopy(v) for v in values]
@@ -1102,6 +1109,7 @@ async def serialize(
11021109
)
11031110
args = args or [types.SaveArgs()] * len(values)
11041111
check_input_arguments(values, infos, args)
1112+
_check_array_values(values)
11051113

11061114
assert all([info.enable_pinned_host_transfer for info in infos]) or all(
11071115
[not info.enable_pinned_host_transfer for info in infos]

0 commit comments

Comments
 (0)