Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/launching.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ before performing a parameter update (simulates larger batch sizes).
* **`checkpointing_options`**:
* `max_to_keep`: Number of recent checkpoints to retain.
* `save_interval_steps`: How often to save a checkpoint.
* `enable_async_checkpointing`: Boolean that enables or disables asynchronous checkpointing.
* `timeout_secs`: Maximum time, in seconds, allowed for asynchronous checkpoint writes to complete.


* **`metrics_logging_options`**: Settings for logging. Includes project name, run name, and flush frequency.
Expand Down
22 changes: 22 additions & 0 deletions docs/reliability.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ training step count. By default, checkpointing is disabled if
`checkpoint_root_directory` is not specified. Users can further customize
checkpointing behavior via `checkpointing_options` in the config.

Users can customize checkpointing behavior in a fine-grained way through the
components defined inside `checkpointing_options`:

* **Save Decision Policies**: Dictates when to initiate a checkpoint based on
defined steps or intervals. Supported configurations include
`FixedIntervalPolicy` and `ContinuousCheckpointingPolicy`. The default is
`ContinuousCheckpointingPolicy(minimum_interval_secs=180)` (saves every 180
seconds). Check Orbax v1 `save_decision_policies.py` for the complete
interface contracts.
* **Preservation Policies**: Determines which checkpoints are retained over
  time (e.g., `LatestN`). The default is `LatestN(n=3)` (keeps the latest 3
  checkpoints). See Orbax v1 `preservation_policies.py`.
* **Step Name Format**: Defines how the directory name for each checkpoint
  step is formatted. The default is `ocp.path.step.standard_name_format()`
  (uses simple integer step names).
* **Asynchronous Processing**: Configure asynchronous checkpointing by
  specifying:
  * `enable_async_checkpointing`: Whether to use async checkpointing.
    Defaults to `True`.
  * `timeout_secs`: The timeout, in seconds, for asynchronous operations.
    Defaults to `1200`.

## Fault Tolerance

Tunix ensures fault tolerance primarily through its checkpointing mechanism,
Expand Down
13 changes: 8 additions & 5 deletions examples/agentic/gemma_grpo_demo_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
import jax
from jax import numpy as jnp
import optax
from orbax import checkpoint as ocp
from orbax.checkpoint import v1 as ocp
from tunix.sft import checkpoint_options

# %%
if ENV == 'g3':
Expand Down Expand Up @@ -306,8 +307,7 @@ def get_ref_model():
abs_state,
nnx.get_named_sharding(abs_state, mesh),
)
checkpointer = ocp.StandardCheckpointer()
restored_params = checkpointer.restore(ckpt_path, target=abs_state)
restored_params = ocp.load_pytree(ckpt_path, abstract_pytree=abs_state)

graph_def, _ = nnx.split(abs_gemma)
gemma = nnx.merge(graph_def, restored_params)
Expand Down Expand Up @@ -499,8 +499,11 @@ def check_numbers(prompts, completions, answer, **kargs):

# %%
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
checkpointing_options = checkpoint_options.create_checkpointing_options(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
SAVE_INTERVAL_STEPS
),
preservation_policy=ocp.training.preservation_policies.LatestN(MAX_TO_KEEP),
)

# %%
Expand Down
28 changes: 7 additions & 21 deletions examples/deepscaler/math_eval_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import jax
from jax import numpy as jnp
from flax import nnx
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp
from tqdm.auto import tqdm
import re

Expand Down Expand Up @@ -238,28 +238,14 @@ def model_from_orbax_ckpt(self):
abs_state,
nnx.get_named_sharding(abs_state, self.mesh),
)
item_handlers = {
"model_params": ocp.PyTreeCheckpointHandler(),
"optimizer_state": ocp.PyTreeCheckpointHandler(),
}
checkpointer = ocp.CheckpointManager(
self.model_path,
item_handlers=item_handlers,
)
model_cp_args = ocp.args.PyTreeRestore(
item=abs_state,
restore_args=ocp.checkpoint_utils.construct_restore_args(
target=abs_state
),
)
ckpt = checkpointer.restore(
160,
args=ocp.args.Composite(
model_params=model_cp_args,
),
step_path = os.path.join(self.model_path, "160")
ckpt_model_params = ocp.load_pytree(
step_path,
abstract_pytree=abs_state,
checkpointable_name="model_params",
)
graphdef, _ = nnx.split(abs_model)
new_state = nnx.State(ckpt.model_params)
new_state = nnx.State(ckpt_model_params)
self.model = nnx.merge(graphdef, new_state)

def load_model(self):
Expand Down
13 changes: 9 additions & 4 deletions examples/deepscaler/train_deepscaler_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from jax import numpy as jnp
import numpy as np
import optax
import optax
from orbax import checkpoint as ocp
from orbax.checkpoint import v1 as ocp
import qwix

# ====== Logging Configuration ======
Expand Down Expand Up @@ -59,6 +58,7 @@
with cm:
from tunix.models.qwen2 import params as params_lib
from tunix.models.qwen2 import model as model_lib
from tunix.sft import checkpoint_options
from tunix.sft import metrics_logger
from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.agentic.agents import model_agent
Expand Down Expand Up @@ -446,8 +446,13 @@ def get_lora_model(base_model, model_mesh):

# %%
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
checkpointing_options = checkpoint_options.create_checkpointing_options(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
SAVE_INTERVAL_STEPS
),
preservation_policy=ocp.training.preservation_policies.LatestN(
MAX_TO_KEEP
),
)

# Metrics logger
Expand Down
12 changes: 9 additions & 3 deletions examples/deepswe/train_deepswe_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from kubernetes import client, config as k8s_config
import numpy as np
import optax
from orbax import checkpoint as ocp
from orbax.checkpoint import v1 as ocp
import qwix
from transformers import AutoTokenizer
from tunix.cli.utils import data as data_lib
Expand Down Expand Up @@ -169,6 +169,7 @@
from tunix.models.qwen3 import params as params_lib
from tunix.models.qwen3 import model as model_lib
from tunix.sft import utils as sft_utils
from tunix.sft import checkpoint_options
from tunix.sft import metrics_logger
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
Expand Down Expand Up @@ -448,8 +449,13 @@ def transform(entry):
# ==========================================
# 9. Optimizer & Checkpointing
# ==========================================
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
checkpointing_options = checkpoint_options.create_checkpointing_options(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
SAVE_INTERVAL_STEPS
),
preservation_policy=ocp.training.preservation_policies.LatestN(
MAX_TO_KEEP
),
)
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
log_dir="/tmp/tensorboard/grpo", flush_every_n_steps=2
Expand Down
9 changes: 3 additions & 6 deletions examples/logit_distillation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"import jax.numpy as jnp\n",
"import kagglehub\n",
"import optax\n",
"from orbax import checkpoint as ocp\n",
"from orbax.checkpoint import v1 as ocp\n",
"from tunix.distillation import distillation_trainer\n",
"from tunix.distillation import strategies\n",
"from tunix.generate import sampler as sampler_lib\n",
Expand Down Expand Up @@ -177,10 +177,8 @@
" gemma = gemma_lib.Gemma.from_params(params, version=ckpt_version)\n",
"\n",
" print(f\"Saving checkpoint to {ckpt_dir}...\")\n",
" checkpointer = ocp.StandardCheckpointer()\n",
" _, state = nnx.split(gemma)\n",
" checkpointer.save(os.path.join(ckpt_dir, \"state\"), state)\n",
" checkpointer.wait_until_finished()\n",
" ocp.save_pytree(os.path.join(ckpt_dir, \"state\"), state)\n",
" # Clean up to save memory\n",
" del params\n",
" del gemma\n",
Expand Down Expand Up @@ -228,8 +226,7 @@
" abs_state,\n",
" nnx.get_named_sharding(abs_state, mesh),\n",
" )\n",
" checkpointer = ocp.StandardCheckpointer()\n",
" restored_params = checkpointer.restore(ckpt_path, target=abs_state)\n",
" restored_params = ocp.load_pytree(ckpt_path, abstract_pytree=abs_state)\n",
"\n",
" graph_def, _ = nnx.split(abs_gemma)\n",
" gemma = nnx.merge(graph_def, restored_params)\n",
Expand Down
12 changes: 8 additions & 4 deletions examples/sft/vlm_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
import jax.numpy as jnp
import numpy as np
import optax
from orbax import checkpoint as ocp
from orbax.checkpoint import v1 as ocp
import qwix
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as model_lib
from tunix.models.gemma3 import params as params_lib
from tunix.processors import image_processor as image_processor_lib
from tunix.sft import checkpoint_options
from tunix.sft import metrics_logger
from tunix.sft import peft_trainer

Expand Down Expand Up @@ -313,9 +314,12 @@ def gen_model_input_fn(x):
log_dir=logging_dir, flush_every_n_steps=20
)
checkpointing_options = None
if full_ckpt_dir is not None:
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=_SAVE_INTERVAL_STEPS.value, max_to_keep=1
if full_ckpt_dir is not None:
checkpointing_options = checkpoint_options.create_checkpointing_options(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
_SAVE_INTERVAL_STEPS.value
),
preservation_policy=ocp.training.preservation_policies.LatestN(1),
)

training_config = peft_trainer.TrainingConfig(
Expand Down
Loading
Loading