Skip to content

Commit 1f5536e

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Introduce an enum CheckpointingImpl to simplify Pathways registration and to allow multiple options to be specified and resolved in order of priority.
PiperOrigin-RevId: 833441656
1 parent 8cb9ca7 commit 1f5536e

File tree

9 files changed

+83
-16
lines changed

9 files changed

+83
-16
lines changed

checkpoint/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
- Validate checkpoints before writing merged OCDBT database using in-memory
1818
state, avoiding additional I/O to re-read metadata.
1919
- add `support_format` to utils.to_shape_dtype_struct()
20+
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
2021

2122
## [0.11.28] - 2025-11-06
2223

checkpoint/orbax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from orbax.checkpoint import msgpack_utils
3333
from orbax.checkpoint import options
3434
from orbax.checkpoint import path
35+
from orbax.checkpoint import pathways
3536
from orbax.checkpoint import serialization
3637
from orbax.checkpoint import transform_utils
3738
from orbax.checkpoint import tree

checkpoint/orbax/checkpoint/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from orbax.checkpoint import msgpack_utils
3333
from orbax.checkpoint import options
3434
from orbax.checkpoint import path
35+
from orbax.checkpoint import pathways
3536
from orbax.checkpoint import serialization
3637
from orbax.checkpoint import transform_utils
3738
from orbax.checkpoint import tree

checkpoint/orbax/checkpoint/_src/serialization/pathways_handler_registry.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
"""Registers the Pathways handlers with the given options."""
1616

17+
from __future__ import annotations
18+
19+
import enum
1720
import types
1821

1922
from absl import logging
@@ -73,19 +76,57 @@ def _register_numpy_and_scalar_handlers():
7376
)
7477

7578

79+
class CheckpointingImpl(enum.Enum):
80+
"""The implementation to use for Pathways checkpointing."""
81+
82+
NO_DISPATCHER = enum.auto()
83+
COLOCATED_PYTHON = enum.auto()
84+
85+
@classmethod
86+
def from_options(
87+
cls,
88+
*,
89+
use_colocated_python: bool = False,
90+
) -> CheckpointingImpl:
91+
"""Obtains a CheckpointingImpl from the given options.
92+
93+
More than one option can be set to True. Resolves in order of priority:
94+
1. Colocated Python
95+
4. No Dispatcher
96+
97+
Args:
98+
use_colocated_python: Whether to use colocated Python. # BEGIN
99+
use_remote_python: Whether to use remote Python.
100+
use_persistence_array_handler: Whether to use the persistence array
101+
102+
Returns:
103+
The CheckpointingImpl to use.
104+
"""
105+
if use_colocated_python:
106+
return cls.COLOCATED_PYTHON
107+
else:
108+
return cls.NO_DISPATCHER
109+
110+
76111
def get_pathways_array_handler(
77112
use_single_replica_array_handler: bool = False,
78-
use_colocated_python: bool = True,
113+
checkpointing_impl: CheckpointingImpl | None = None,
79114
**kwargs,
80115
) -> type_handlers.ArrayHandler:
81116
"""Returns the Pathways ArrayHandler with the given options."""
82-
83-
if use_colocated_python:
84-
logging.info('Using ColocatedPythonDispatcher')
85-
dispatcher = dispatchers.ColocatedPythonDispatcher()
86-
else:
87-
logging.info('Not using dispatcher')
88-
dispatcher = None
117+
# If not set, use whichever dispatcher implementation is available.
118+
checkpointing_impl = checkpointing_impl or CheckpointingImpl.from_options(
119+
use_colocated_python=True,
120+
)
121+
match checkpointing_impl:
122+
case CheckpointingImpl.COLOCATED_PYTHON:
123+
logging.info('Using ColocatedPythonDispatcher')
124+
dispatcher = dispatchers.ColocatedPythonDispatcher()
125+
case CheckpointingImpl.NO_DISPATCHER:
126+
logging.info('Not using dispatcher')
127+
dispatcher = None
128+
case _:
129+
raise ValueError(f'Unsupported CheckpointingImpl: {checkpointing_impl}')
89130

90131
return _get_array_hander_with_dispatcher(
91132
dispatcher,
@@ -96,15 +137,15 @@ def get_pathways_array_handler(
96137

97138
def register_pathways_handlers(
98139
use_single_replica_array_handler: bool = False,
99-
use_colocated_python: bool = True,
140+
checkpointing_impl: CheckpointingImpl | None = None,
100141
**kwargs,
101142
):
102143
"""Registers the Pathways handlers with the given options.
103144
104145
Args:
105146
use_single_replica_array_handler: Whether to use the
106147
SingleReplicaArrayHandler.
107-
use_colocated_python: Use ColocatedPythonDispatcher with jax array handler.
148+
checkpointing_impl: The implementation to use for Pathways checkpointing.
108149
**kwargs: Keyword arguments to pass to the ArrayHandler.
109150
"""
110151
_register_numpy_and_scalar_handlers()
@@ -113,7 +154,7 @@ def register_pathways_handlers(
113154
jax.Array,
114155
get_pathways_array_handler(
115156
use_single_replica_array_handler,
116-
use_colocated_python,
157+
checkpointing_impl=checkpointing_impl,
117158
**kwargs,
118159
),
119160
override=True,

checkpoint/orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,11 @@ def register_array_type_handler(self, options: PyTreeCheckpointOptions):
124124
else:
125125
if options.use_jax_array_handler:
126126
if options.use_persistence_array_handler:
127-
ocp.type_handlers.register_pathways_handlers(
127+
ocp.pathways.register_type_handlers(
128128
use_persistence_array_handler=options.use_persistence_array_handler,
129129
)
130130
else:
131-
ocp.type_handlers.register_pathways_handlers(
131+
ocp.pathways.register_type_handlers(
132132
use_colocated_python=options.use_colocated_python,
133133
use_replica_parallel=options.use_replica_parallel,
134134
enable_replica_parallel_separate_folder=options.enable_replica_parallel_separate_folder,

checkpoint/orbax/checkpoint/_src/testing/benchmarks/single_replica_benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def test_fn(
119119

120120
pathways_handler_registry.register_pathways_handlers(
121121
use_single_replica_array_handler=True,
122-
use_colocated_python=options.use_colocated_python,
122+
checkpointing_impl=pathways_handler_registry.CheckpointingImpl.from_options(
123+
use_colocated_python=options.use_colocated_python,
124+
),
123125
replica_axis_index=options.replica_axis_index,
124126
primary_replica_id=options.primary_replica_id,
125127
use_replica_parallel=options.use_replica_parallel,

checkpoint/orbax/checkpoint/_src/testing/benchmarks/single_replica_benchmark_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ def test_test_fn_handler_options(
237237
mock_is_runtime_to_distributed_ids_initialized.assert_called_once()
238238
register_pathways_handlers.assert_called_once_with(
239239
use_single_replica_array_handler=True,
240-
use_colocated_python=options.use_colocated_python,
240+
checkpointing_impl=pathways_handler_registry.CheckpointingImpl.from_options(
241+
use_colocated_python=options.use_colocated_python,
242+
),
241243
replica_axis_index=options.replica_axis_index,
242244
primary_replica_id=options.primary_replica_id,
243245
use_replica_parallel=options.use_replica_parallel,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Public symbols for Pathways-related serialization."""
16+
17+
# pylint: disable=g-importing-member, unused-import, g-bad-import-order
18+
19+
from orbax.checkpoint._src.serialization.pathways_handler_registry import CheckpointingImpl
20+
from orbax.checkpoint._src.serialization.pathways_handler_registry import register_pathways_handlers as register_type_handlers

checkpoint/orbax/checkpoint/type_handlers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from orbax.checkpoint._src.serialization.type_handler_registry import get_type_handler
3535
from orbax.checkpoint._src.serialization.type_handler_registry import has_type_handler
3636
from orbax.checkpoint._src.serialization.type_handler_registry import register_standard_handlers_with_options
37-
from orbax.checkpoint._src.serialization.pathways_handler_registry import register_pathways_handlers
3837
from orbax.checkpoint._src.serialization.type_handler_registry import register_type_handler
3938
from orbax.checkpoint._src.serialization.type_handler_registry import supported_types
4039

0 commit comments

Comments
 (0)