Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/reliability.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ training step count. By default, checkpointing is disabled if
`checkpoint_root_directory` is not specified. Users can further customize
checkpointing behavior via `checkpointing_options` in the config.

Users can customize checkpointing behavior at a granular level using the
components defined inside `checkpointing_options`:

* **Save Decision Policies**: Dictates when to initiate a checkpoint based on
defined steps or intervals. Supported configurations include
`FixedIntervalPolicy` and `ContinuousCheckpointingPolicy`. The default is
`ContinuousCheckpointingPolicy(minimum_interval_secs=180)` (saves every 180
seconds). Check Orbax v1 `save_decision_policies.py` for the complete
interface contracts.
* **Preservation Policies**: Controls which checkpoints are retained over
time (e.g., `LatestN`). The default is
`LatestN(n=3)` (keeps the latest 3 checkpoints). See Orbax v1
`preservation_policies.py`.
* **Step Name Format**: Defines the representation of directory names for step
checkpoints. The default is `ocp.path.step.standard_name_format()` (uses
simple integer step names).
* **Asynchronous Processing**: Manages asynchronous behavior by specifying:
* `enable_async_checkpointing`: Whether to use async checkpointing.
Defaults to `True`.
* `timeout_secs`: The timeout for asynchronous operations.
Defaults to `1200` seconds.

## Fault Tolerance

Tunix ensures fault tolerance primarily through its checkpointing mechanism,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"kagglehub",
"numba",
"omegaconf", # CLI config
"orbax-checkpoint>=0.11.36",
"pillow", # Image processing
"pylatexenc", # Eval result parsing
"python-dotenv", # Huggingface API key
Expand Down
95 changes: 85 additions & 10 deletions tests/sft/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import optax
import qwix
from tunix.sft import checkpoint_manager
from tunix.sft import checkpoint_options

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

Expand Down Expand Up @@ -110,20 +111,26 @@ def test_empty_root_directory(self):
def test_checkpoint_manager_options_none_sets_default(self):
  """Passing options=None must fall back to the library-wide defaults."""
  cp_path = f'{self.temp_path}/{self.id()}'
  cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=None)
  # The underlying v1 checkpointer is still created without explicit options.
  self.assertIsNotNone(cp_manager._checkpointer)
  # The resolved options must be the shared default options object.
  self.assertEqual(
      cp_manager._options,
      checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS,
  )

def test_context_property(self):
  """A freshly constructed manager exposes a non-None `context`."""
  checkpoint_path = f'{self.temp_path}/{self.id()}'
  manager = checkpoint_manager.CheckpointManager(checkpoint_path)
  self.assertIsNotNone(manager.context)

def test_save(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

# Save the model state.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()
self.assertEqual(cp_manager.latest_step(), 1)

cp_manager.close()
Expand All @@ -139,7 +146,8 @@ def test_restore(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand All @@ -162,7 +170,8 @@ def test_restore_different_sharding(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, unsharded_model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Restore the model without shardings.
self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {}))
Expand Down Expand Up @@ -211,7 +220,8 @@ def test_restore_with_lora(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert cp_manager._checkpointer is not None
cp_manager._checkpointer.wait()

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand Down Expand Up @@ -241,7 +251,8 @@ def test_save_and_restore_with_custom_metadata(self):
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()
restored_step, restored_metadata = ckpt_manager.maybe_restore(model)
self.assertEqual(restored_step, 1)
self.assertEqual(restored_metadata, custom_metadata)
Expand All @@ -257,7 +268,8 @@ def test_save_and_restore_with_optimizer_state(self):
)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()

new_optimizer = nnx.Optimizer(
model,
Expand All @@ -281,6 +293,68 @@ def test_save_and_restore_with_optimizer_state(self):
new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3
)

def test_save_and_restore_with_forced_single_device_sharding(self):
  """Restore must re-shard values that were forced onto a single device.

  Saves model + optimizer state, then deliberately moves one optimizer
  hyperparameter onto a single device (SingleDeviceSharding) before
  restoring, and verifies that restoration re-applies NamedSharding to
  every leaf of the optimizer state.
  """
  cp_path = f'{self.temp_path}/{self.id()}'
  ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
  model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
  optimizer = nnx.Optimizer(
      model,
      optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
      wrt=nnx.Param,
  )
  custom_metadata = {'foo': 1, 'bar': 2}
  ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
  assert ckpt_manager._checkpointer is not None
  ckpt_manager._checkpointer.wait()

  new_optimizer = nnx.Optimizer(
      model,
      optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5),
      wrt=nnx.Param,
  )

  # Force the hyperparameter onto one device so restore has to fix sharding.
  new_optimizer.opt_state.hyperparams['learning_rate'].value = jax.device_put(
      new_optimizer.opt_state.hyperparams['learning_rate'].value,
      jax.devices()[0],
  )

  # Precondition: the value really is single-device sharded before restore.
  self.assertIsInstance(
      new_optimizer.opt_state.hyperparams['learning_rate'].value.sharding,
      jax.sharding.SingleDeviceSharding,
  )

  restored_step, _ = ckpt_manager.maybe_restore(model, new_optimizer)
  self.assertEqual(restored_step, 1)

  errors = []

  def assert_named_sharding(path, x):
    # Collect failures instead of raising so every leaf gets inspected.
    if not hasattr(x, 'sharding'):
      return
    try:
      self.assertIsInstance(
          x.sharding,
          jax.sharding.NamedSharding,
          f'Variable at {path} is not NamedSharding',
      )
    except AssertionError as e:
      errors.append(str(e))
      # Not a NamedSharding, so there is no meaningful .spec to check below.
      return
    if 'hyperparams' in str(path):
      try:
        # Hyperparameters are expected to be replicated (empty PartitionSpec).
        self.assertEqual(x.sharding.spec, jax.sharding.PartitionSpec())
      except AssertionError as e:
        errors.append(str(e))

  jax.tree.map_with_path(
      assert_named_sharding,
      nnx.state(new_optimizer, nnx.optimizer.OptState),
  )
  if errors:
    # Join outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError on Python < 3.12 (PEP 701 relaxed this only in 3.12).
    joined_errors = '\n'.join(errors)
    self.fail(f'Found sharding mismatches:\n{joined_errors}')

def test_restore_without_optimizer(self):
cp_path = f'{self.temp_path}/{self.id()}'
ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
Expand All @@ -291,7 +365,8 @@ def test_restore_without_optimizer(self):
wrt=nnx.Param,
)
ckpt_manager.save(1, model, optimizer)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
assert ckpt_manager._checkpointer is not None
ckpt_manager._checkpointer.wait()
ckpt_manager.maybe_restore(model)

@parameterized.parameters(['test_data/checkpoints'])
Expand Down
126 changes: 126 additions & 0 deletions tests/sft/checkpoint_options_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Tunix Checkpointing Options custom implementation for Orbax v1."""

from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import orbax.checkpoint as ocp_v0
from orbax.checkpoint import v1 as ocp
from tunix.sft import checkpoint_options


class CheckpointOptionsTest(parameterized.TestCase):
  """Tests resolving Tunix checkpointing options onto Orbax v1 policies."""

  def test_resolve_checkpointing_defaults_with_none(self):
    # None must resolve to the module-level default options object.
    opts = checkpoint_options.resolve_checkpointing_defaults(None)
    self.assertEqual(opts, checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS)

  def test_resolve_checkpointing_defaults_with_deprecated_options(self):
    # Legacy v0 CheckpointManagerOptions should be translated to the
    # equivalent v1 policies and emit deprecation warnings while doing so.
    legacy_opts = ocp_v0.CheckpointManagerOptions(
        save_interval_steps=100, max_to_keep=5
    )

    with self.assertLogs(level='WARNING') as log:
      opts = checkpoint_options.resolve_checkpointing_defaults(
          legacy_opts
      )

    # Verify deprecation warnings were logged
    v0_warnings = [msg for msg in log.output if 'Using v0' in msg]
    self.assertNotEmpty(v0_warnings)

    # Verify policies were resolved correctly
    self.assertEqual(
        opts.save_decision_policy,
        ocp.training.save_decision_policies.FixedIntervalPolicy(100),
    )

    self.assertEqual(
        opts.preservation_policy,
        ocp.training.preservation_policies.LatestN(5),
    )

  def test_resolve_checkpointing_defaults_with_legacy_options_dataclass(self):
    # A v0 save_decision_policy dataclass maps onto the matching v1 policy,
    # preserving its configuration (minimum_interval_secs).
    legacy_opts = ocp_v0.CheckpointManagerOptions(
        save_decision_policy=ocp_v0.checkpoint_managers.ContinuousCheckpointingPolicy(
            minimum_interval_secs=10,
        ),
    )
    opts = checkpoint_options.resolve_checkpointing_defaults(
        legacy_opts
    )
    self.assertIsInstance(
        opts.save_decision_policy,
        ocp.training.save_decision_policies.ContinuousCheckpointingPolicy,
    )
    # pytype: disable=attribute-error
    self.assertEqual(opts.save_decision_policy.minimum_interval_secs, 10)
    # pytype: enable=attribute-error

  def test_resolve_checkpointing_defaults_with_async_timeout(self):
    # Explicit async options (custom timeout) must survive resolution.
    async_opts = ocp.options.AsyncOptions(timeout_secs=5000)
    options = mock.create_autospec(
        checkpoint_options.TunixCheckpointingOptions, instance=True
    )
    # Leave every other field None so only async_options drives the result.
    options.async_options = async_opts
    options.save_decision_policy = None
    options.preservation_policy = None
    options.step_name_format = None
    options.enable_async_checkpointing = None

    opts = checkpoint_options.resolve_checkpointing_defaults(options)
    self.assertIsNotNone(opts.async_options)
    assert opts.async_options is not None
    self.assertEqual(opts.async_options.timeout_secs, 5000)

  def test_resolve_checkpointing_defaults_with_modern_options(self):
    # Fully-specified modern options should pass through resolution unchanged.
    modern_opts = checkpoint_options.TunixCheckpointingOptions(
        save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
            50
        ),
        preservation_policy=ocp.training.preservation_policies.LatestN(10),
        enable_async_checkpointing=False,
    )
    opts = checkpoint_options.resolve_checkpointing_defaults(
        modern_opts
    )
    self.assertEqual(
        opts.save_decision_policy, modern_opts.save_decision_policy
    )
    self.assertEqual(
        opts.preservation_policy, modern_opts.preservation_policy
    )
    self.assertFalse(opts.enable_async_checkpointing)

  def test_create_checkpointing_options(self):
    # The factory returns a TunixCheckpointingOptions mirroring its arguments.
    opts = checkpoint_options.create_checkpointing_options(
        save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
            50
        ),
        preservation_policy=ocp.training.preservation_policies.LatestN(10),
        enable_async_checkpointing=False,
    )
    self.assertIsInstance(opts, checkpoint_options.TunixCheckpointingOptions)
    self.assertEqual(
        opts.save_decision_policy,
        ocp.training.save_decision_policies.FixedIntervalPolicy(50),
    )
    self.assertEqual(
        opts.preservation_policy,
        ocp.training.preservation_policies.LatestN(10),
    )
    self.assertFalse(opts.enable_async_checkpointing)

# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()
8 changes: 4 additions & 4 deletions tests/sft/peft_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import jax.sharding as shd
import numpy as np
import optax
import orbax.checkpoint as ocp
from tunix.sft import checkpoint_manager
from tunix.sft import checkpoint_options
from tunix.sft import hooks
from tunix.sft import peft_trainer
from tunix.sft import profiler
Expand Down Expand Up @@ -547,13 +547,13 @@ def test_checkpointing(
mock_checkpoint_manager.latest_step.return_value = (
expected_save_steps[-1] - 1
) # force save at close
checkpoint_options = ocp.CheckpointManagerOptions()
checkpointing_options = checkpoint_options.create_checkpointing_options()
config = peft_trainer.TrainingConfig(
eval_every_n_steps=2,
max_steps=100,
gradient_accumulation_steps=grad_accu,
checkpoint_root_directory='/tmp/checkpoint',
checkpointing_options=checkpoint_options,
checkpointing_options=checkpointing_options,
)
rngs = nnx.Rngs(0)
model = tc.get_lora_model(
Expand All @@ -566,7 +566,7 @@ def test_checkpointing(
trainer.train(train_ds, eval_ds)

mock_checkpoint_manager_init.assert_called_once_with(
root_directory='/tmp/checkpoint', options=checkpoint_options
root_directory='/tmp/checkpoint', options=checkpointing_options
)
# Assert that the checkpoint manager is called with the correct arguments
# and does not have any unexpected calls.
Expand Down
Loading
Loading