Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/launching.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ before performing a parameter update (simulates larger batch sizes).
* **`checkpointing_options`**:
* `max_to_keep`: Number of recent checkpoints to retain.
* `save_interval_steps`: How often to save a checkpoint.
* `enable_async_checkpointing`: Boolean to toggle asynchronous checkpointing execution.
* `timeout_secs`: Maximum time permitted for an asynchronous checkpoint write to complete.


* **`metrics_logging_options`**: Settings for logging. Includes project name, run name, and flush frequency.
Expand Down
22 changes: 22 additions & 0 deletions docs/reliability.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ training step count. By default, checkpointing is disabled if
`checkpoint_root_directory` is not specified. Users can further customize
checkpointing behavior via `checkpointing_options` in the config.

Users can customize checkpointing behavior at a fine-grained level using the
components defined in `checkpointing_options`:

* **Save Decision Policies**: Dictates when to initiate a checkpoint based on
defined steps or intervals. Supported configurations include
`FixedIntervalPolicy` and `ContinuousCheckpointingPolicy`. The default is
`ContinuousCheckpointingPolicy(minimum_interval_secs=180)` (saves every 180
seconds). Check Orbax v1 `save_decision_policies.py` for the complete
interface contracts.
* **Preservation Policies**: Determines which checkpoints are retained over
  time (e.g., `LatestN`). The default is `LatestN(n=3)` (keeps the latest 3
  checkpoints). See Orbax v1 `preservation_policies.py`.
* **Step Name Format**: Defines the representation of directory names for step
checkpoints. The default is `ocp.path.step.standard_name_format()` (uses
simple integer step names).
* **Asynchronous Processing**: Manage asynchronous behavior by specifying:
* `enable_async_checkpointing`: Whether to use async checkpointing.
Defaults to `True`.
* `timeout_secs`: The timeout for asynchronous operations.
Defaults to `1200` seconds.

## Fault Tolerance

Tunix ensures fault tolerance primarily through its checkpointing mechanism,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"kagglehub",
"numba",
"omegaconf", # CLI config
"orbax-checkpoint>=0.11.36",
"pillow", # Image processing
"pylatexenc", # Eval result parsing
"python-dotenv", # Huggingface API key
Expand Down
15 changes: 10 additions & 5 deletions tests/rl/agentic/agentic_grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp
from tunix.generate import tokenizer_adapter
from tunix.rl import common as rl_common
from tunix.rl import function_registry
Expand All @@ -48,6 +48,7 @@
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
from tunix.rl.queue import data_queue as queue_lib
from tunix.rl.rollout import base_rollout
from tunix.sft import checkpoint_options
from tunix.sft import metrics_logger
from tunix.tests import test_common
from tunix.utils import trajectory_logger
Expand Down Expand Up @@ -697,8 +698,10 @@ def create_learner(
train_micro_batch_size=mini_batch_size,
rollout_micro_batch_size=mini_batch_size,
compute_logps_micro_batch_size=mini_batch_size,
checkpointing_options=ocp.CheckpointManagerOptions(
save_interval_steps=1,
checkpointing_options=checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
),
),
checkpoint_root_directory=ckpt_dir,
),
Expand Down Expand Up @@ -886,8 +889,10 @@ def create_learner(

mesh = pxla.thread_resources.env.physical_mesh
if ckpt_dir:
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=1,
checkpointing_options = checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
),
)
else:
checkpointing_options = None
Expand Down
19 changes: 13 additions & 6 deletions tests/rl/grpo/grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp
from tunix.perf import trace as trace_lib
from tunix.perf.experimental import tracer as perf_tracer_v2
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo import grpo_learner as grpo_lib
from tunix.rl.queue import data_queue as queue_lib
from tunix.rl.rollout import base_rollout
from tunix.sft import checkpoint_options
from tunix.sft import profiler
from tunix.tests import test_common as tc
from typing_extensions import override
Expand Down Expand Up @@ -718,9 +719,13 @@ def test_resume_training(self):
eval_every_n_steps=10,
max_steps=10,
checkpoint_root_directory=temp_path,
checkpointing_options=ocp.CheckpointManagerOptions(
save_interval_steps=1,
max_to_keep=10,
checkpointing_options=checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
),
preservation_policy=ocp.training.preservation_policies.LatestN(
10
),
),
),
rollout_config=base_rollout.RolloutConfig(
Expand Down Expand Up @@ -1050,8 +1055,10 @@ def create_learner(
train_micro_batch_size=mini_batch_size,
rollout_micro_batch_size=mini_batch_size,
compute_logps_micro_batch_size=mini_batch_size,
checkpointing_options=ocp.CheckpointManagerOptions(
save_interval_steps=4,
checkpointing_options=checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(4)
),
),
checkpoint_root_directory=ckpt_dir,
),
Expand Down
87 changes: 77 additions & 10 deletions tests/sft/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import optax
import qwix
from tunix.sft import checkpoint_manager
from tunix.sft import checkpoint_options

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

Expand Down Expand Up @@ -110,20 +111,25 @@ def test_empty_root_directory(self):
def test_checkpoint_manager_options_none_sets_default(self):
  """Passing options=None must fall back to the shared default options."""
  target_dir = f'{self.temp_path}/{self.id()}'
  manager = checkpoint_manager.CheckpointManager(target_dir, options=None)
  # The underlying checkpointer is still constructed without explicit options.
  self.assertIsNotNone(manager._checkpointer)
  # And the effective options are exactly the library defaults.
  self.assertEqual(
      manager._options,
      checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS,
  )

def test_context_property(self):
  """The manager exposes a non-None context via its `context` property."""
  manager = checkpoint_manager.CheckpointManager(
      f'{self.temp_path}/{self.id()}'
  )
  self.assertIsNotNone(manager.context)

def test_save(self):
  """Saving a model state records that step as the latest checkpoint."""
  manager = checkpoint_manager.CheckpointManager(
      f'{self.temp_path}/{self.id()}'
  )
  model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

  # Kick off the (asynchronous) save and block until it has landed.
  self.assertTrue(manager.save(1, model))
  manager._checkpointer.wait()  # pytype: disable=attribute-error
  self.assertEqual(manager.latest_step(), 1)

  manager.close()
Expand All @@ -139,7 +145,7 @@ def test_restore(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand All @@ -162,7 +168,7 @@ def test_restore_different_sharding(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, unsharded_model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error

# Restore the model without shardings.
self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {}))
Expand Down Expand Up @@ -211,7 +217,7 @@ def test_restore_with_lora(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand Down Expand Up @@ -241,7 +247,7 @@ def test_save_and_restore_with_custom_metadata(self):
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
ckpt_manager._checkpointer.wait() # pytype: disable=attribute-error
restored_step, restored_metadata = ckpt_manager.maybe_restore(model)
self.assertEqual(restored_step, 1)
self.assertEqual(restored_metadata, custom_metadata)
Expand All @@ -257,7 +263,7 @@ def test_save_and_restore_with_optimizer_state(self):
)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
ckpt_manager._checkpointer.wait() # pytype: disable=attribute-error

new_optimizer = nnx.Optimizer(
model,
Expand All @@ -281,6 +287,67 @@ def test_save_and_restore_with_optimizer_state(self):
new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3
)

def test_save_and_restore_with_forced_single_device_sharding(self):
  """Restore must re-shard state that was forced onto a single device.

  Saves model + optimizer, forces the new optimizer's learning-rate
  hyperparam onto one device (SingleDeviceSharding), restores, and then
  checks every leaf of the restored optimizer state carries a
  NamedSharding — with hyperparams fully replicated (empty PartitionSpec).
  """
  cp_path = f'{self.temp_path}/{self.id()}'
  ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
  model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
  optimizer = nnx.Optimizer(
      model,
      optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
      wrt=nnx.Param,
  )
  custom_metadata = {'foo': 1, 'bar': 2}
  ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
  ckpt_manager._checkpointer.wait()  # pytype: disable=attribute-error

  new_optimizer = nnx.Optimizer(
      model,
      optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5),
      wrt=nnx.Param,
  )

  # Force the learning-rate hyperparam onto a single device so the restore
  # path has to repair its sharding.
  new_optimizer.opt_state.hyperparams['learning_rate'].value = jax.device_put(
      new_optimizer.opt_state.hyperparams['learning_rate'].value,
      jax.devices()[0],
  )

  # Precondition: the forced value really is single-device sharded.
  self.assertIsInstance(
      new_optimizer.opt_state.hyperparams['learning_rate'].value.sharding,
      jax.sharding.SingleDeviceSharding,
  )

  restored_step, _ = ckpt_manager.maybe_restore(model, new_optimizer)
  self.assertEqual(restored_step, 1)

  # Collect every mismatch instead of failing on the first, so a failure
  # reports all offending leaves at once.
  errors = []

  def assert_named_sharding(path, x):
    # Leaves without a .sharding attribute (non-array state) are skipped.
    if hasattr(x, 'sharding'):
      try:
        self.assertIsInstance(
            x.sharding,
            jax.sharding.NamedSharding,
            f'Variable at {path} is not NamedSharding',
        )
      except AssertionError as e:
        errors.append(str(e))
        return

      # Hyperparams (e.g. learning_rate) are expected to be replicated,
      # i.e. carry an empty PartitionSpec.
      path_str = str(path)
      if 'hyperparams' in path_str:
        try:
          self.assertEqual(x.sharding.spec, jax.sharding.PartitionSpec())
        except AssertionError as e:
          errors.append(str(e))

  jax.tree.map_with_path(
      assert_named_sharding,
      nnx.state(new_optimizer, nnx.optimizer.OptState),
  )
  if errors:
    # BUGFIX: the join must happen outside the f-string — a backslash
    # inside an f-string expression is a SyntaxError before Python 3.12.
    self.fail('Found sharding mismatches:\n' + '\n'.join(errors))

def test_restore_without_optimizer(self):
cp_path = f'{self.temp_path}/{self.id()}'
ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
Expand All @@ -291,7 +358,7 @@ def test_restore_without_optimizer(self):
wrt=nnx.Param,
)
ckpt_manager.save(1, model, optimizer)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
ckpt_manager._checkpointer.wait() # pytype: disable=attribute-error
ckpt_manager.maybe_restore(model)

@parameterized.parameters(['test_data/checkpoints'])
Expand Down
Loading
Loading