Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
state, avoiding additional I/O to re-read metadata.
- add `support_format` to utils.to_shape_dtype_struct()
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
- Replace usage of `get_json_tpec_read` and delegate functionality to new
function `build_array_read_spec` which constructs and returns an
`ArrayReadSpec`.

## [0.11.28] - 2025-11-06

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,12 +763,13 @@ async def _async_deserialize(
await _validate_non_ocdbt_files(infos, metadata_key)
deserialize_ops = []
for info, arg, sharding in zip(infos, args, shardings):
tspec = ts_utils.get_json_tspec_read(
array_read_spec = ts_utils.build_array_read_spec(
info,
use_ocdbt=use_ocdbt,
metadata_key=metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)

# set dtype=None to deserialize for random keys
Expand Down Expand Up @@ -939,19 +940,6 @@ def __init__(
def has_dispatcher(self) -> bool:
return self._dispatcher is not None

def _get_json_tspec_read(
self,
info: types.ParamInfo,
use_ocdbt: bool,
) -> Dict[str, Any]:
"""Gets Tensorstore spec for reading."""
return ts_utils.get_json_tspec_read(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)

def typestr(self) -> str:
return JAX_ARRAY_TYPE_STR

Expand All @@ -968,7 +956,13 @@ async def metadata(
for info in infos:
# Use OCDBT flag from the existing checkpoint.
use_ocdbt = info.is_ocdbt_checkpoint
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
array_read_spec = ts_utils.build_array_read_spec(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
open_ops.append(
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec(
return array_tspec


class ArrayReadSpec:
  """Full TensorStore spec for reading an array."""

  def __init__(
      self,
      directory: str,
      relative_array_filename: str,
      use_zarr3: bool,
      *,
      use_ocdbt: bool,
      metadata_key: str | None = None,
      raise_array_data_missing_error: bool = True,
  ):
    """Builds a TensorStore spec for reading an array.

    Args:
      directory: Base directory containing the array.
      relative_array_filename: Array file name, relative to `directory`.
      use_zarr3: Whether to read with the zarr3 driver (zarr2 otherwise).
      use_ocdbt: Whether the array is stored in an OCDBT kvstore.
      metadata_key: Optional custom zarr metadata key; omitted from the spec
        when None.
      raise_array_data_missing_error: When True, missing array data raises an
        error rather than being filled on read.
    """
    kvstore = build_kvstore_tspec(
        directory,
        name=relative_array_filename,
        use_ocdbt=use_ocdbt,
        process_id=None,
    )
    spec = {
        'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2,
        'kvstore': kvstore,
        'recheck_cached_data': False,
        'recheck_cached_metadata': False,
        # `fill_missing_data_reads` is the inverse of raising on missing data.
        'fill_missing_data_reads': not raise_array_data_missing_error,
    }
    if metadata_key is not None:
      spec['metadata_key'] = metadata_key
    self._json_spec = spec

  @property
  def json(self) -> JsonSpec:
    """Spec to be used to open a TensorStore for reading the array."""
    return self._json_spec


class ArrayWriteSpec:
"""Full TensorStore spec for writing an array."""

Expand Down Expand Up @@ -677,6 +716,26 @@ def get_json_tspec_write(
return tspec


def build_array_read_spec(
    info: types.ParamInfo,
    *,
    use_ocdbt: bool,
    metadata_key: str | None = None,
    raise_array_data_missing_error: bool = True,
) -> ArrayReadSpec:
  """Gets ArrayReadSpec for reading.

  Args:
    info: ParamInfo carrying the array's name, parent directory and zarr
      version.
    use_ocdbt: Whether the array is stored in an OCDBT kvstore.
    metadata_key: Optional custom zarr metadata key.
    raise_array_data_missing_error: When True, missing array data raises an
      error rather than being filled on read.

  Returns:
    An `ArrayReadSpec` for the array described by `info`.

  Raises:
    ValueError: If `info.name` or `info.parent_dir` is None.
  """
  name = info.name
  parent_dir = info.parent_dir
  if name is None or parent_dir is None:
    raise ValueError('Must provide info.name and info.parent_dir.')
  return ArrayReadSpec(
      directory=parent_dir.as_posix(),
      relative_array_filename=name,
      use_zarr3=info.use_zarr3,
      use_ocdbt=use_ocdbt,
      metadata_key=metadata_key,
      raise_array_data_missing_error=raise_array_data_missing_error,
  )


def build_array_write_spec(
info: types.ParamInfo,
arg: types.SaveArgs | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self):
self.assertTrue(ts_utils.is_remote_storage(nested_tspec))


class BuildArrayTSpecForReadTest(parameterized.TestCase):
  """Tests for `ts_utils.ArrayReadSpec`."""

  def setUp(self):
    super().setUp()
    self.directory = self.create_tempdir().full_path
    self.param_name = 'params/a'
    # Pre-bind the array location so each test only varies the read options.
    self.array_read_spec_constructor = functools.partial(
        ts_utils.ArrayReadSpec,
        directory=self.directory,
        relative_array_filename=self.param_name,
    )

  @parameterized.product(
      use_zarr3=(True, False),
      use_ocdbt=(True, False),
  )
  def test_basic(self, use_zarr3: bool, use_ocdbt: bool):
    spec = self.array_read_spec_constructor(
        use_zarr3=use_zarr3,
        use_ocdbt=use_ocdbt,
    ).json
    expected_driver = 'zarr3' if use_zarr3 else 'zarr'
    expected_kvstore_driver = 'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER
    self.assertEqual(spec['driver'], expected_driver)
    self.assertEqual(spec['kvstore']['driver'], expected_kvstore_driver)
    # Cache rechecks are always disabled for reads.
    self.assertFalse(spec['recheck_cached_data'])
    self.assertFalse(spec['recheck_cached_metadata'])
    self.assertFalse(spec['fill_missing_data_reads'])
    self.assertNotIn('metadata_key', spec)

  def test_metadata_key(self):
    read_spec = self.array_read_spec_constructor(
        use_zarr3=False,
        use_ocdbt=False,
        metadata_key='custom_metadata',
    )
    self.assertEqual(read_spec.json['metadata_key'], 'custom_metadata')

  @parameterized.parameters(True, False)
  def test_fill_missing_data_reads(self, raise_array_data_missing_error):
    read_spec = self.array_read_spec_constructor(
        use_zarr3=False,
        use_ocdbt=False,
        raise_array_data_missing_error=raise_array_data_missing_error,
    )
    # `fill_missing_data_reads` must be the inverse of the raise flag.
    self.assertEqual(
        read_spec.json['fill_missing_data_reads'],
        not raise_array_data_missing_error,
    )


class GetTsContextTest(parameterized.TestCase):

@parameterized.product(
Expand Down
60 changes: 21 additions & 39 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,6 @@ def __init__(
self._metadata_key = metadata_key
self._override_ocdbt_process_id = ocdbt_process_id

def _get_array_write_spec(
self,
info: types.ParamInfo,
value: np.ndarray,
use_ocdbt: bool,
process_index: Optional[Union[int, str]] = None,
arg: Optional[types.SaveArgs] = None,
) -> ts_utils.ArrayWriteSpec:
"""Gets ArrayWriteSpec for writing."""
return ts_utils.build_array_write_spec(
info=info,
arg=arg,
global_shape=value.shape,
local_shape=value.shape,
dtype=value.dtype,
use_ocdbt=use_ocdbt,
process_index=process_index,
metadata_key=self._metadata_key,
)

def _get_json_tspec_read(
self,
info: types.ParamInfo,
use_ocdbt: bool,
) -> Dict[str, Any]:
"""Gets Tensorstore spec for reading."""
return ts_utils.get_json_tspec_read(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)

def typestr(self) -> str:
return 'np.ndarray'

Expand All @@ -120,7 +87,13 @@ async def metadata(
for info in infos:
# Use OCDBT flag from the existing checkpoint.
use_ocdbt = info.is_ocdbt_checkpoint
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
array_read_spec = ts_utils.build_array_read_spec(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
open_ops.append(
ts.open(ts.Spec(tspec), open=True, context=info.ts_context)
)
Expand Down Expand Up @@ -149,15 +122,18 @@ async def _background_serialize(
"""Serializes numpy arrays in a background thread."""
write_coros = []
for value, info, arg in zip(values, infos, args):
array_write_spec = self._get_array_write_spec(
info,
value,
array_write_spec = ts_utils.build_array_write_spec(
info=info,
arg=arg,
global_shape=value.shape,
local_shape=value.shape,
dtype=value.dtype,
use_ocdbt=info.is_ocdbt_checkpoint,
process_index=ocdbt_utils.get_process_index_for_subdir(
use_ocdbt=info.is_ocdbt_checkpoint,
override_ocdbt_process_id=self._override_ocdbt_process_id,
),
arg=arg,
metadata_key=self._metadata_key,
)
tspec = array_write_spec.json
if logging.vlog_is_on(1):
Expand Down Expand Up @@ -205,7 +181,13 @@ async def deserialize(
)
# Use OCDBT flag from the existing checkpoint.
use_ocdbt = info.is_ocdbt_checkpoint
tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt)
array_read_spec = ts_utils.build_array_read_spec(
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)
tspec = array_read_spec.json
tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg)

if logging.vlog_is_on(1):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ async def open_tensorstore(
use_zarr3=use_zarr3,
ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt),
)
tspec = ts_utils.get_json_tspec_read(info, use_ocdbt=use_ocdbt)
array_read_spec = ts_utils.build_array_read_spec(info, use_ocdbt=use_ocdbt)
tspec = array_read_spec.json
return await ts.open(
ts.Spec(tspec),
read=True,
Expand Down
9 changes: 9 additions & 0 deletions export/orbax/export/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,18 @@ class ExportModelType(enum.Enum):
# Mesh for the model.
JAX_MESH = 'jax_mesh'

# TODO: b/459991985 - Remove this flag and use PERSIST_XLA_FLAGS instead.
# Whether to strip XLA flags from the model.
STRIP_XLA_FLAGS = 'strip_xla_flags'

# Whether to persist XLA flags in the model.
PERSIST_XLA_FLAGS = 'persist_xla_flags'

# Whether to enable bf16 optimization for the model.
# TODO: b/422170690 - (1) Apply this flag to the pre/post processors. (2) Add
# filter flags once the flag is applied to the pre/post processors.
ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization'

################################################################################
# Proto field names
################################################################################
Expand Down
10 changes: 10 additions & 0 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
tensorflow_module.TensorFlowModule, self._export_module
).jax2tf_kwargs_map

@property
def jax2obm_kwargs(self) -> Mapping[str, Any]:
"""Returns the jax2obm_kwargs."""
if self._export_version == constants.ExportModelType.TF_SAVEDMODEL:
raise TypeError(
'jax2obm_kwargs is not implemented for export version'
' ExportModelType.TF_SAVEDMODEL.'
)
return cast(obm_module.ObmModule, self._export_module).jax2obm_kwargs

@property
def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]:
"""Returns the polymorphic shapes."""
Expand Down
Loading
Loading