25 changes: 23 additions & 2 deletions .github/workflows/build.yml
@@ -242,7 +242,7 @@ jobs:
working-directory: checkpoint
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10"]
jax-version: ["0.6.0"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -267,9 +267,30 @@
- name: Run benchmarks
env:
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
TF_FORCE_GPU_ALLOW_GROWTH: true
XLA_PYTHON_CLIENT_PREALLOCATE: false
KERAS_BACKEND: "jax"
run: |
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
cd orbax/checkpoint/_src/testing/benchmarks
failed_benchmarks=""
benchmark_configs_file="multiprocess_benchmark_configs.txt"
echo "Running benchmarks specified in $benchmark_configs_file"
while IFS= read -r entry || [ -n "$entry" ]; do
if [ -n "$entry" ]; then
echo "Running benchmark for $entry"
if ! python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file="$entry"', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)"; then
echo "Benchmark $entry failed"
failed_benchmarks="$failed_benchmarks $entry"
fi
fi
done < "$benchmark_configs_file"
cd ../../../../..
if [ -n "$failed_benchmarks" ]; then
echo "The following benchmarks failed:$failed_benchmarks"
exit 1
fi
# cd orbax/checkpoint/_src/testing/benchmarks && python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file=configs/pytree_checkpoint_benchmark.yaml', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)"
# cd ../../../../..
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
# The below step just reports the success or failure of tests as a "commit status".
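The inline `python -c` driver exists so that `jax.distributed.initialize()` runs before any other JAX call in each worker process, while still dispatching into `run_benchmarks.main` through absl. A standalone sketch of the same driver, for readability; the file name is hypothetical, and it assumes, as the inline command does, that `run_benchmarks.py` is importable from the benchmarks directory:

```python
# run_distributed_benchmark.py -- hypothetical standalone equivalent of the
# inline `python -c` driver in the workflow step above.
import sys

import jax
from absl import app

import run_benchmarks  # assumes CWD is the benchmarks directory

if __name__ == '__main__':
  # Join the multi-process JAX cluster before any other JAX work.
  jax.distributed.initialize()
  print(jax.devices())
  # Inject the flags run_benchmarks.py expects, mirroring the workflow.
  sys.argv = [
      'run_benchmarks.py',
      '--config_file=configs/pytree_checkpoint_benchmark.yaml',
      '--output_directory=gs://example-bucket/results',  # placeholder
  ]
  app.run(run_benchmarks.main)
```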
22 changes: 20 additions & 2 deletions .github/workflows/multiprocess_tests.yml
@@ -25,7 +25,7 @@ jobs:
working-directory: checkpoint
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.12"]
jax-version: ["0.6.0"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -50,9 +50,27 @@
- name: Run benchmarks
env:
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
TF_FORCE_GPU_ALLOW_GROWTH: true
XLA_PYTHON_CLIENT_PREALLOCATE: false
run: |
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
cd orbax/checkpoint/_src/testing/benchmarks
failed_benchmarks=""
benchmark_configs_file="multiprocess_benchmark_configs.txt"
echo "Running benchmarks specified in $benchmark_configs_file"
while IFS= read -r entry || [ -n "$entry" ]; do
if [ -n "$entry" ]; then
echo "Running benchmark for $entry"
if ! python run_benchmarks.py --config_file="$entry" --output_directory=$GCS_BUCKET_PATH; then
echo "Benchmark $entry failed"
failed_benchmarks="$failed_benchmarks $entry"
fi
fi
done < "$benchmark_configs_file"
cd ../../../../..
if [ -n "$failed_benchmarks" ]; then
echo "The following benchmarks failed:$failed_benchmarks"
exit 1
fi
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
# The below step just reports the success or failure of tests as a "commit status".
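Both workflows drive the loop from `multiprocess_benchmark_configs.txt`, which is not included in this diff. Given how the loop consumes it (one config path per non-empty line, resolved relative to the benchmarks directory), its contents presumably look like the following; this single entry is only an illustration, inferred from the config referenced elsewhere in this change:

```text
configs/pytree_checkpoint_benchmark.yaml
```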
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -12,13 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to
include an arbitrary `step_prefix` with any character(s) such as underscores.
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
- Fix using `jax.eval_shape` with `StandardRestore`.

### Changed

- Validate checkpoints before writing merged OCDBT database using in-memory
state, avoiding additional I/O to re-read metadata.
- Add `support_format` to `utils.to_shape_dtype_struct()`.
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
- Replace usage of `get_json_tspec_read`, delegating to the new function
  `build_array_read_spec`, which constructs and returns an
  `ArrayReadSpec`.

## [0.11.28] - 2025-11-06

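The `build_array_read_spec` entry corresponds to a mechanical migration at read call sites, visible in the type-handler hunks below. A minimal before/after sketch using the signatures shown in this diff:

```python
# Before: build the raw JSON TensorStore spec directly.
tspec = ts_utils.get_json_tspec_read(
    info,
    use_ocdbt=use_ocdbt,
    metadata_key=metadata_key,
    raise_array_data_missing_error=info.raise_array_data_missing_error,
)

# After: build an ArrayReadSpec object and take its .json property.
array_read_spec = ts_utils.build_array_read_spec(
    info,
    use_ocdbt=use_ocdbt,
    metadata_key=metadata_key,
    raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
```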
@@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler:
return StandardCheckpointHandler()

def test_with_random_keys(self):
# TODO(b/393160483) investigate Pathways remote Python support for
# random.keys.
if utils.is_pathways_backend():
self.skipTest('Pathways does not support random keys checkpoint.')
self.skipTest(
'Disabled on Pathways because random keys are not supported by'
' remote Python.'
)

def create_random_keys(seed):
duplicated_sharding = jax.sharding.NamedSharding(
@@ -559,3 +564,38 @@ def create_random_keys(seed):
args=self.restore_args_cls(abstract_tree),
)
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_save_restore_random_keys_with_jax_eval_shape(self):
# TODO(b/393160483) investigate Pathways remote Python support for
# random.keys.
if utils.is_pathways_backend():
self.skipTest(
'Disabled on Pathways because random keys are not supported by'
' remote Python.'
)

mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

@functools.partial(
jax.jit,
in_shardings=sharding,
out_shardings=sharding,
)
def sharded_create_state_fn(root_key):
return dict(
matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
rngkey=jax.random.fold_in(root_key, 42),
)

pytree = sharded_create_state_fn(jax.random.key(0))
abstract_pytree = jax.eval_shape(
sharded_create_state_fn, jax.random.key(0)
)

self.handler.save(self.directory, args=self.save_args_cls(pytree))

restored = self.handler.restore(
self.directory, args=self.restore_args_cls(abstract_pytree)
)
test_utils.assert_tree_equal(self, pytree, restored)
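
This test exercises the `jax.eval_shape` fix noted in the CHANGELOG: tracing the jitted state constructor produces abstract leaves that carry shape and dtype, including the extended PRNG-key dtype, without materializing any arrays, and that abstract pytree is what `StandardRestore` consumes. A minimal illustration outside the test harness:

```python
# Sketch of what jax.eval_shape yields for a state containing a random key.
import jax
import jax.numpy as jnp

def create_state(root_key):
  return dict(
      matrix=jnp.zeros((4, 2)),
      rngkey=jax.random.fold_in(root_key, 42),
  )

abstract = jax.eval_shape(create_state, jax.random.key(0))
print(abstract['matrix'])        # ShapeDtypeStruct(shape=(4, 2), dtype=float32)
print(abstract['rngkey'].dtype)  # key<fry> -- an extended PRNG-key dtype
```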
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/_src/path/atomicity.py
@@ -161,13 +161,13 @@ async def _create_tmp_directory(
def _get_tmp_directory(final_path: epath.Path) -> epath.Path:
# Path may not be completely unique if a preemption occurs. We rely on the
# existing tmp directory being deleted elsewhere.
return epath.Path(final_path.parent) / (final_path.name + TMP_DIR_SUFFIX)
return final_path.parent / (final_path.name + TMP_DIR_SUFFIX)


def _get_final_directory(tmp_path: epath.Path) -> epath.Path:
if (suffix_idx := tmp_path.name.find(TMP_DIR_SUFFIX)) == -1:
raise ValueError(f'Expected {tmp_path} to end with "{TMP_DIR_SUFFIX}".')
return epath.Path(tmp_path.parent) / tmp_path.name[:suffix_idx]
return tmp_path.parent / tmp_path.name[:suffix_idx]


class TemporaryPathBase(atomicity_types.TemporaryPath):
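Both helpers previously wrapped their result in `epath.Path(...)`; the wrapper is redundant because `.parent` on an `epath.Path` already returns an `epath.Path`. The two functions remain inverses of each other; a minimal round-trip sketch, with the `TMP_DIR_SUFFIX` value assumed for illustration (its real definition lives elsewhere in atomicity.py):

```python
from etils import epath

TMP_DIR_SUFFIX = '.orbax-checkpoint-tmp-'  # assumed value, for illustration

def get_tmp_directory(final_path: epath.Path) -> epath.Path:
  # .parent is already an epath.Path; no re-wrapping needed.
  return final_path.parent / (final_path.name + TMP_DIR_SUFFIX)

def get_final_directory(tmp_path: epath.Path) -> epath.Path:
  if (suffix_idx := tmp_path.name.find(TMP_DIR_SUFFIX)) == -1:
    raise ValueError(f'Expected {tmp_path} to end with "{TMP_DIR_SUFFIX}".')
  return tmp_path.parent / tmp_path.name[:suffix_idx]

final = epath.Path('/ckpts/step_100')
assert get_final_directory(get_tmp_directory(final)) == final
```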
@@ -763,12 +763,13 @@ async def _async_deserialize(
await _validate_non_ocdbt_files(infos, metadata_key)
deserialize_ops = []
for info, arg, sharding in zip(infos, args, shardings):
tspec = ts_utils.get_json_tspec_read(
array_read_spec = ts_utils.build_array_read_spec(
info,
use_ocdbt=use_ocdbt,
metadata_key=metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)

# set dtype=None to deserialize for random keys
@@ -939,19 +940,6 @@ def __init__(
def has_dispatcher(self) -> bool:
return self._dispatcher is not None

def _get_json_tspec_read(
self,
info: types.ParamInfo,
use_ocdbt: bool,
) -> Dict[str, Any]:
"""Gets Tensorstore spec for reading."""
return ts_utils.get_json_tspec_read(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)

def typestr(self) -> str:
return JAX_ARRAY_TYPE_STR

@@ -968,7 +956,13 @@ async def metadata(
for info in infos:
# Use OCDBT flag from the existing checkpoint.
use_ocdbt = info.is_ocdbt_checkpoint
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
array_read_spec = ts_utils.build_array_read_spec(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
open_ops.append(
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
)
@@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec(
return array_tspec


class ArrayReadSpec:
"""Full TensorStore spec for reading an array."""

def __init__(
self,
directory: str,
relative_array_filename: str,
use_zarr3: bool,
*,
use_ocdbt: bool,
metadata_key: str | None = None,
raise_array_data_missing_error: bool = True,
):
"""Builds a TensorStore spec for reading an array."""
kvstore_tspec = build_kvstore_tspec(
directory,
name=relative_array_filename,
use_ocdbt=use_ocdbt,
process_id=None,
)

tspec = {
'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2,
'kvstore': kvstore_tspec,
'recheck_cached_data': False,
'recheck_cached_metadata': False,
# Raise error if data is missing.
'fill_missing_data_reads': not raise_array_data_missing_error,
}
if metadata_key is not None:
tspec['metadata_key'] = metadata_key
self._json_spec = tspec

@property
def json(self) -> JsonSpec:
"""Spec to be used to open a TensorStore for reading the array."""
return self._json_spec


class ArrayWriteSpec:
"""Full TensorStore spec for writing an array."""

@@ -677,6 +716,26 @@ def get_json_tspec_write(
return tspec


def build_array_read_spec(
info: types.ParamInfo,
*,
use_ocdbt: bool,
metadata_key: str | None = None,
raise_array_data_missing_error: bool = True,
) -> ArrayReadSpec:
"""Gets ArrayReadSpec for reading."""
if info.name is None or info.parent_dir is None:
raise ValueError('Must provide info.name and info.parent_dir.')
return ArrayReadSpec(
directory=info.parent_dir.as_posix(),
relative_array_filename=info.name,
use_zarr3=info.use_zarr3,
use_ocdbt=use_ocdbt,
metadata_key=metadata_key,
raise_array_data_missing_error=raise_array_data_missing_error,
)


def build_array_write_spec(
info: types.ParamInfo,
arg: types.SaveArgs | None = None,
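Putting the pieces together: a read call site now obtains an `ArrayReadSpec` from `build_array_read_spec` and opens a TensorStore from its `.json` property, which is what the deserialization and metadata paths above now do, replacing the deleted private `_get_json_tspec_read` helper. A minimal sketch, with the import path and a populated `types.ParamInfo` assumed:

```python
import tensorstore as ts
# Import path assumed from the orbax source layout.
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils

async def open_param_for_read(info, metadata_key=None):
  # Build the JSON TensorStore spec from ParamInfo, per this diff.
  array_read_spec = ts_utils.build_array_read_spec(
      info,
      use_ocdbt=info.is_ocdbt_checkpoint,
      metadata_key=metadata_key,
      raise_array_data_missing_error=info.raise_array_data_missing_error,
  )
  tspec = array_read_spec.json
  # Open a read handle, as the metadata path does above.
  return await ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
```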
@@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self):
self.assertTrue(ts_utils.is_remote_storage(nested_tspec))


class BuildArrayTSpecForReadTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.directory = self.create_tempdir().full_path
self.param_name = 'params/a'

self.array_read_spec_constructor = functools.partial(
ts_utils.ArrayReadSpec,
directory=self.directory,
relative_array_filename=self.param_name,
)

@parameterized.product(
use_zarr3=(True, False),
use_ocdbt=(True, False),
)
def test_basic(self, use_zarr3: bool, use_ocdbt: bool):
tspec = self.array_read_spec_constructor(
use_zarr3=use_zarr3,
use_ocdbt=use_ocdbt,
)
json_spec = tspec.json
self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr')
self.assertEqual(
json_spec['kvstore']['driver'],
'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER,
)
self.assertFalse(json_spec['recheck_cached_data'])
self.assertFalse(json_spec['recheck_cached_metadata'])
self.assertFalse(json_spec['fill_missing_data_reads'])
self.assertNotIn('metadata_key', json_spec)

def test_metadata_key(self):
tspec = self.array_read_spec_constructor(
use_zarr3=False,
use_ocdbt=False,
metadata_key='custom_metadata',
)
self.assertEqual(tspec.json['metadata_key'], 'custom_metadata')

@parameterized.parameters(True, False)
def test_fill_missing_data_reads(self, raise_array_data_missing_error):
tspec = self.array_read_spec_constructor(
use_zarr3=False,
use_ocdbt=False,
raise_array_data_missing_error=raise_array_data_missing_error,
)
self.assertEqual(
tspec.json['fill_missing_data_reads'],
not raise_array_data_missing_error,
)


class GetTsContextTest(parameterized.TestCase):

@parameterized.product(