Skip to content

#v1 Rename GenericHandler to StatefulCheckpointableHandler, remove metadata method, and other minor changes. #1883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ py_library(
name = "global_registration",
srcs = ["global_registration.py"],
deps = [
":generic_handler",
":json_handler",
":proto_handler",
":pytree_handler",
":registration",
":stateful_checkpointable_handler",
":types",
"//orbax/checkpoint/experimental/v1/_src/path:format_utils",
],
Expand Down Expand Up @@ -179,13 +179,10 @@ py_test(
)

py_library(
name = "generic_handler",
srcs = ["generic_handler.py"],
name = "stateful_checkpointable_handler",
srcs = ["stateful_checkpointable_handler.py"],
deps = [
":proto_handler",
":types",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//orbax/checkpoint/experimental/v1/_src/context",
"//orbax/checkpoint/experimental/v1/_src/path:types",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

from typing import Type

from orbax.checkpoint.experimental.v1._src.handlers import generic_handler
from orbax.checkpoint.experimental.v1._src.handlers import json_handler
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.path import format_utils

Expand All @@ -43,7 +43,9 @@ def _try_register_handler(

_try_register_handler(proto_handler.ProtoHandler)
_try_register_handler(json_handler.JsonHandler)
_try_register_handler(generic_handler.GenericHandler)
_try_register_handler(
stateful_checkpointable_handler.StatefulCheckpointableHandler
)
_try_register_handler(
json_handler.MetricsHandler,
format_utils.METRICS_CHECKPOINTABLE_KEY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""GenericHandler class."""
"""StatefulCheckpointableHandler class."""

from typing import Any, Awaitable, Generic
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types

T = handler_types.T

class GenericHandler(
handler_types.CheckpointableHandler[handler_types.T, handler_types.T],
Generic[handler_types.T],

class StatefulCheckpointableHandler(
handler_types.CheckpointableHandler[T, T],
Generic[T],
):
"""Serializes/deserializes a Checkpointable."""

async def save(
self,
directory: path_types.PathAwaitingCreation,
checkpointable: handler_types.StatefulCheckpointable[handler_types.T],
checkpointable: handler_types.StatefulCheckpointable[T],
) -> Awaitable[None]:
return await checkpointable.save(directory)

async def _background_load(
self,
directory: path_types.Path,
abstract_checkpointable: handler_types.StatefulCheckpointable[
handler_types.T
],
) -> handler_types.T:
await abstract_checkpointable.load(directory)
return abstract_checkpointable

async def load(
self,
directory: path_types.Path,
abstract_checkpointable: (
handler_types.StatefulCheckpointable[handler_types.T] | None
handler_types.StatefulCheckpointable[T] | None
) = None,
) -> Awaitable[handler_types.T]:
) -> Awaitable[T]:
if abstract_checkpointable is None:
raise ValueError(
'Abstract checkpointable is required for GenericHandler.load.'
'To restore a `StatefulCheckpointable`, you must pass an instance of'
' the object.'
)
return self._background_load(directory, abstract_checkpointable)

async def metadata(self, directory: path_types.Path) -> handler_types.T:
# TODO(yaning): Implement Metadata.
raise NotImplementedError()
# Returns Awaitable[None]
background_load = await abstract_checkpointable.load(directory)

async def _background_load() -> T:
await background_load
# After loading, `abstract_checkpointable` (actually just a concrete
# checkpointable) should be populated with the loaded data.
return abstract_checkpointable

return _background_load()

async def metadata(self, directory: path_types.Path) -> T:
raise NotImplementedError(
'Metadata retrieval is not supported for objects implementing'
' `StatefulCheckpointable`.'
)

def is_handleable(self, checkpointable: Any) -> bool:
# TODO(yaning): Add test for a class that partially implements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ async def load(
"""Loads the checkpointable from the given `directory`."""
...

async def metadata(self, directory: path_types.Path) -> T:
"""Returns the metadata for the given `directory`."""
...


class CheckpointableHandler(Protocol[T, AbstractT]):
"""An interface that defines save/load logic for a `checkpointable` object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//third_party/py/aiofiles",
"//orbax/checkpoint/experimental/v1",
"//orbax/checkpoint/experimental/v1/_src/handlers:generic_handler",
"//orbax/checkpoint/experimental/v1/_src/handlers:stateful_checkpointable_handler",
"//orbax/checkpoint/experimental/v1/_src/path:types",
"//orbax/checkpoint/experimental/v1/_src/synchronization:multihost",
"//orbax/checkpoint/experimental/v1/_src/tree:types",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,26 @@ async def background_load(

@dataclasses.dataclass
class Point:
"""Implements StatefulCheckpointable."""
x: int
y: int

def __eq__(self, other: Point) -> bool:
return isinstance(other, Point) and self.x == other.x and self.y == other.y

async def save(self, directory: path_types.PathAwaitingCreation):
async def save(
self, directory: path_types.PathAwaitingCreation
) -> Awaitable[None]:
return DataclassHandler().background_save(directory, self)

async def load(
self,
directory: path_types.Path,
) -> None:
async def _background_load(self, directory: path_types.Path):
async with aiofiles.open(directory / 'foo.txt', 'r') as f:
contents = json.loads(await f.read())
self.x = contents['x']
self.y = contents['y']

async def metadata(self, directory: path_types.Path) -> Point:
raise NotImplementedError()
async def load(self, directory: path_types.Path) -> Awaitable[None]:
return self._background_load(directory)


@dataclasses.dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def test_partial_restore_with_placeholder(self, use_async):
):
self.load_and_wait(directory, reference_item, use_async=use_async)

def test_checkpointable_with_generic_handler(self):
def test_checkpointable_with_stateful_checkpointable(self):
point = handler_utils.Point(1, 2)
checkpointables = {'point': point}
ocp.save_checkpointables(self.directory, checkpointables)
Expand Down