Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"kagglehub",
"numba",
"omegaconf", # CLI config
"orbax-checkpoint>=0.11.36",
"pillow", # Image processing
"pylatexenc", # Eval result parsing
"python-dotenv", # Huggingface API key
Expand Down
126 changes: 126 additions & 0 deletions tests/sft/checkpoint_options_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Tunix Checkpointing Options custom implementation for Orbax v1."""

from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import orbax.checkpoint as ocp_v0
from orbax.checkpoint import v1 as ocp
from tunix.sft import checkpoint_options


class CheckpointOptionsTest(parameterized.TestCase):
  """Tests option resolution and construction in `checkpoint_options`."""

  def test_resolve_checkpointing_defaults_with_none(self):
    """Resolving `None` yields the module-level default options."""
    resolved = checkpoint_options.resolve_checkpointing_defaults(None)
    self.assertEqual(
        resolved, checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS
    )

  def test_resolve_checkpointing_defaults_with_deprecated_options(self):
    """v0 interval/keep fields become v1 policies and log a warning."""
    v0_options = ocp_v0.CheckpointManagerOptions(
        max_to_keep=5, save_interval_steps=100
    )

    with self.assertLogs(level='WARNING') as captured:
      resolved = checkpoint_options.resolve_checkpointing_defaults(v0_options)

    # The v0 fallback path must announce its deprecation.
    deprecation_lines = [
        line for line in captured.output if 'Using v0' in line
    ]
    self.assertNotEmpty(deprecation_lines)

    # The v0 scalar fields must have been translated into v1 policies.
    expected_save = ocp.training.save_decision_policies.FixedIntervalPolicy(
        100
    )
    self.assertEqual(resolved.save_decision_policy, expected_save)

    expected_preservation = ocp.training.preservation_policies.LatestN(5)
    self.assertEqual(resolved.preservation_policy, expected_preservation)

  def test_resolve_checkpointing_defaults_with_legacy_options_dataclass(self):
    """A save policy set on v0 options is carried through resolution."""
    continuous_policy = (
        ocp_v0.checkpoint_managers.ContinuousCheckpointingPolicy(
            minimum_interval_secs=10,
        )
    )
    v0_options = ocp_v0.CheckpointManagerOptions(
        save_decision_policy=continuous_policy,
    )
    resolved = checkpoint_options.resolve_checkpointing_defaults(v0_options)
    self.assertIsInstance(
        resolved.save_decision_policy,
        ocp.training.save_decision_policies.ContinuousCheckpointingPolicy,
    )
    # pytype: disable=attribute-error
    self.assertEqual(resolved.save_decision_policy.minimum_interval_secs, 10)
    # pytype: enable=attribute-error

  def test_resolve_checkpointing_defaults_with_async_timeout(self):
    """A caller-supplied async timeout is kept by resolution."""
    stub_options = mock.create_autospec(
        checkpoint_options.TunixCheckpointingOptions, instance=True
    )
    # Leave every other field unset so only the async timeout is
    # caller-provided.
    stub_options.save_decision_policy = None
    stub_options.preservation_policy = None
    stub_options.step_name_format = None
    stub_options.enable_async_checkpointing = None
    stub_options.async_options = ocp.options.AsyncOptions(timeout_secs=5000)

    resolved = checkpoint_options.resolve_checkpointing_defaults(stub_options)
    self.assertIsNotNone(resolved.async_options)
    # Plain assert narrows the Optional type for static checkers.
    assert resolved.async_options is not None
    self.assertEqual(resolved.async_options.timeout_secs, 5000)

  def test_resolve_checkpointing_defaults_with_modern_options(self):
    """Explicitly set Tunix options pass through resolution unchanged."""
    save_policy = ocp.training.save_decision_policies.FixedIntervalPolicy(50)
    tunix_options = checkpoint_options.TunixCheckpointingOptions(
        save_decision_policy=save_policy,
        preservation_policy=ocp.training.preservation_policies.LatestN(10),
        enable_async_checkpointing=False,
    )
    resolved = checkpoint_options.resolve_checkpointing_defaults(
        tunix_options
    )
    self.assertEqual(
        resolved.save_decision_policy, tunix_options.save_decision_policy
    )
    self.assertEqual(
        resolved.preservation_policy, tunix_options.preservation_policy
    )
    self.assertFalse(resolved.enable_async_checkpointing)

  def test_create_checkpointing_options(self):
    """The factory returns a `TunixCheckpointingOptions` with given fields."""
    save_policies = ocp.training.save_decision_policies
    preservation_policies = ocp.training.preservation_policies

    created = checkpoint_options.create_checkpointing_options(
        save_decision_policy=save_policies.FixedIntervalPolicy(50),
        preservation_policy=preservation_policies.LatestN(10),
        enable_async_checkpointing=False,
    )
    self.assertIsInstance(
        created, checkpoint_options.TunixCheckpointingOptions
    )
    self.assertEqual(
        created.save_decision_policy,
        save_policies.FixedIntervalPolicy(50),
    )
    self.assertEqual(
        created.preservation_policy,
        preservation_policies.LatestN(10),
    )
    self.assertFalse(created.enable_async_checkpointing)

if __name__ == '__main__':
  # Run the tests through absl's runner when executed as a script.
  absltest.main()
214 changes: 214 additions & 0 deletions tunix/sft/checkpoint_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing options for Tunix."""

import dataclasses
from typing import Generic, Protocol, TypeVar
from absl import logging

from orbax.checkpoint import v1 as ocp

# Metadata type carried by a step name format; it differs between Orbax v0
# and v1, which is why `NameFormat` below is generic over it.
MetadataT = TypeVar("MetadataT")


# Orbax v0 and v1 ship different async-options implementations, and the
# metadata types behind their name formats diverge as well. Tunix therefore
# describes both structurally with Protocols (see also `AsyncOptions`), so that
# objects from either version are accepted during the transition.
class NameFormat(Protocol, Generic[MetadataT]):
  """Structural placeholder for a step name format.

  The protocol declares no members, so it places no structural requirement on
  the objects used as a step name format. It is parameterized by `MetadataT`
  only so annotations can name the metadata type of either Orbax version.
  """

  ...


class AsyncOptions(Protocol):
@property
def timeout_secs(self) -> int | None:
...


class CheckpointingOptions(Protocol):
  """Read-only view of the checkpointing settings Tunix consumes.

  This is a structural protocol: any object exposing the five attributes below
  qualifies, including the legacy v0 `ocp.CheckpointManagerOptions`, Tunix's
  own `TunixCheckpointingOptions`, or a custom implementation. Such an object
  can be supplied directly to the Checkpointer.

  Every attribute may be None, meaning the caller did not specify it.
  """

  @property
  def save_decision_policy(self) -> (
      ocp.training.save_decision_policies.SaveDecisionPolicy | None
  ):
    """Policy deciding when a checkpoint is saved, or None if unset."""
    ...

  @property
  def preservation_policy(self) -> (
      ocp.training.preservation_policies.PreservationPolicy | None
  ):
    """Policy deciding which checkpoints are preserved, or None if unset."""
    ...

  @property
  def step_name_format(self) -> NameFormat | None:
    """Format used for step names, or None if unset."""
    ...

  @property
  def enable_async_checkpointing(self) -> bool | None:
    """Whether async checkpointing is used, or None if unset."""
    ...

  @property
  def async_options(self) -> AsyncOptions | None:
    """Options for async operations, or None if unset."""
    ...


@dataclasses.dataclass(frozen=True)
class TunixCheckpointingOptions:
  """Immutable Tunix checkpointing options built on Orbax v1 types.

  Every field defaults to None, meaning "not specified";
  `resolve_checkpointing_defaults` replaces unspecified fields with the values
  from `DEFAULT_CHECKPOINTING_OPTIONS`.

  Attributes:
    save_decision_policy: Policy that defines when to save a checkpoint.
    preservation_policy: Policy that defines which checkpoints to preserve.
    step_name_format: Format for step names.
    enable_async_checkpointing: Whether to use async checkpointing.
    async_options: Options for async operations.
  """

  save_decision_policy: (
      ocp.training.save_decision_policies.SaveDecisionPolicy | None
  ) = None
  preservation_policy: (
      ocp.training.preservation_policies.PreservationPolicy | None
  ) = None
  step_name_format: ocp.path.step.NameFormat | None = None
  enable_async_checkpointing: bool | None = None
  async_options: ocp.options.AsyncOptions | None = None


# Tunix-wide fallback values, used for any option the caller leaves unset:
# - Async checkpointing is enabled, with a 1200 second timeout for async
#   operations.
# - Continuous checkpointing with at least 180 seconds between saves.
# - Only the latest 3 checkpoints are kept.
# - Step names use Orbax's standard name format.
DEFAULT_CHECKPOINTING_OPTIONS = TunixCheckpointingOptions(
    enable_async_checkpointing=True,
    async_options=ocp.options.AsyncOptions(timeout_secs=1200),
    save_decision_policy=(
        ocp.training.save_decision_policies.ContinuousCheckpointingPolicy(
            minimum_interval_secs=180,
        )
    ),
    preservation_policy=ocp.training.preservation_policies.LatestN(n=3),
    step_name_format=ocp.path.step.standard_name_format(),
)


def create_checkpointing_options(
    *,
    save_decision_policy: (
        ocp.training.save_decision_policies.SaveDecisionPolicy | None
    ) = None,
    preservation_policy: (
        ocp.training.preservation_policies.PreservationPolicy | None
    ) = None,
    step_name_format: ocp.path.step.NameFormat | None = None,
    enable_async_checkpointing: bool | None = None,
    async_options: ocp.options.AsyncOptions | None = None,
) -> TunixCheckpointingOptions:
  """Builds a `TunixCheckpointingOptions` from keyword-only fields.

  Each argument is forwarded unchanged to the dataclass constructor. No
  defaults are resolved here: an argument left as None stays None.

  Args:
    save_decision_policy: Policy that defines when to save a checkpoint.
    preservation_policy: Policy that defines which checkpoints to preserve.
    step_name_format: Format for step names.
    enable_async_checkpointing: Whether to use async checkpointing.
    async_options: Options for async operations.

  Returns:
    A new `TunixCheckpointingOptions` holding exactly the given values.
  """
  field_values = {
      "save_decision_policy": save_decision_policy,
      "preservation_policy": preservation_policy,
      "step_name_format": step_name_format,
      "enable_async_checkpointing": enable_async_checkpointing,
      "async_options": async_options,
  }
  return TunixCheckpointingOptions(**field_values)


def _resolve_save_decision_policy(
    options: CheckpointingOptions,
) -> ocp.training.save_decision_policies.SaveDecisionPolicy | None:
  """Picks the save policy: explicit policy, v0 interval, then default."""
  policy = options.save_decision_policy
  if policy is not None:
    return policy

  # save_interval_steps is a v0 CheckpointManagerOptions construct only. We
  # fall back to it for backward compatibility if v1 policies are not set.
  # TODO(b/497926314): Remove this fallback once we no longer support v0.
  save_interval = getattr(options, "save_interval_steps", None)
  if save_interval is None:
    return DEFAULT_CHECKPOINTING_OPTIONS.save_decision_policy

  logging.warning(
      "Using v0 ocp.CheckpointManagerOptions is deprecated, along with"
      " save_interval_steps. Please use a checkpointing_options with"
      " save_decision_policy instead."
  )
  return ocp.training.save_decision_policies.FixedIntervalPolicy(
      save_interval
  )


def _resolve_preservation_policy(
    options: CheckpointingOptions,
) -> ocp.training.preservation_policies.PreservationPolicy | None:
  """Picks the preservation policy: explicit, v0 max_to_keep, then default."""
  policy = options.preservation_policy
  if policy is not None:
    return policy

  # max_to_keep is a v0 CheckpointManagerOptions construct only. We fall
  # back to it for backward compatibility if v1 policies are not set.
  # TODO(b/497926314): Remove this fallback once we no longer support v0.
  max_to_keep = getattr(options, "max_to_keep", None)
  if max_to_keep is None:
    return DEFAULT_CHECKPOINTING_OPTIONS.preservation_policy

  logging.warning(
      "Using v0 ocp.CheckpointManagerOptions is deprecated, along with"
      " max_to_keep. Please use a checkpointing_options with"
      " preservation_policy instead."
  )
  return ocp.training.preservation_policies.LatestN(max_to_keep)


def _resolve_async_options(
    options: CheckpointingOptions,
) -> ocp.options.AsyncOptions | None:
  """Builds async options from the caller's timeout, or returns the default."""
  # Read each attribute exactly once so the None-check and the value that is
  # used can never disagree, even for property- or mock-backed objects.
  user_async_options = options.async_options
  timeout_secs = (
      None if user_async_options is None else user_async_options.timeout_secs
  )
  if timeout_secs is None:
    return DEFAULT_CHECKPOINTING_OPTIONS.async_options

  # We want to only allow configuration of timeout_secs, and not the entire
  # async_options, so we create a new AsyncOptions object here.
  return ocp.options.AsyncOptions(timeout_secs=timeout_secs)


def resolve_checkpointing_defaults(
    options: CheckpointingOptions | None = None,
) -> TunixCheckpointingOptions:
  """Resolves options adhering to CheckpointingOptions protocol.

  This function accepts any object fulfilling the `CheckpointingOptions`
  protocol and cleanly extracts fields essential for Tunix. Legacy v0 fields
  (`save_interval_steps` or `max_to_keep`) are applied strictly as second-tier
  fallbacks, matching the explicit internal configuration logic used by Orbax V0
  for backwards compatibility. Using either legacy field logs a deprecation
  warning.

  Resolution order per field:
    * save_decision_policy: `options.save_decision_policy`, else a
      `FixedIntervalPolicy` built from `save_interval_steps`, else the default.
    * preservation_policy: `options.preservation_policy`, else a `LatestN`
      built from `max_to_keep`, else the default.
    * step_name_format, enable_async_checkpointing: the given value, else the
      default.
    * async_options: only `timeout_secs` is honored; a fresh `AsyncOptions` is
      built from it, else the default is used.

  Args:
    options: The options object to resolve. If None,
      `DEFAULT_CHECKPOINTING_OPTIONS` is returned as-is.

  Returns:
    A resolved `TunixCheckpointingOptions` instance.
  """
  if options is None:
    return DEFAULT_CHECKPOINTING_OPTIONS

  # Save policy is resolved before preservation policy so that any
  # deprecation warnings are emitted in a stable order.
  save_policy = _resolve_save_decision_policy(options)
  preserve_policy = _resolve_preservation_policy(options)

  step_name_format = options.step_name_format
  if step_name_format is None:
    step_name_format = DEFAULT_CHECKPOINTING_OPTIONS.step_name_format

  enable_async = options.enable_async_checkpointing
  if enable_async is None:
    enable_async = DEFAULT_CHECKPOINTING_OPTIONS.enable_async_checkpointing

  return create_checkpointing_options(
      save_decision_policy=save_policy,
      preservation_policy=preserve_policy,
      step_name_format=step_name_format,
      enable_async_checkpointing=enable_async,
      async_options=_resolve_async_options(options),
  )
Loading