Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"kagglehub",
"numba",
"omegaconf", # CLI config
"orbax-checkpoint>=0.11.35",
"pillow", # Image processing
"pylatexenc", # Eval result parsing
"python-dotenv", # Huggingface API key
Expand Down
15 changes: 10 additions & 5 deletions tests/rl/agentic/agentic_grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp
from tunix.generate import tokenizer_adapter
from tunix.rl import common as rl_common
from tunix.rl import function_registry
Expand All @@ -48,6 +48,7 @@
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
from tunix.rl.queue import data_queue as queue_lib
from tunix.rl.rollout import base_rollout
from tunix.sft import checkpoint_options
from tunix.sft import metrics_logger
from tunix.tests import test_common
from tunix.utils import trajectory_logger
Expand Down Expand Up @@ -697,8 +698,10 @@ def create_learner(
train_micro_batch_size=mini_batch_size,
rollout_micro_batch_size=mini_batch_size,
compute_logps_micro_batch_size=mini_batch_size,
checkpointing_options=ocp.CheckpointManagerOptions(
save_interval_steps=1,
checkpointing_options=checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
),
),
checkpoint_root_directory=ckpt_dir,
),
Expand Down Expand Up @@ -886,8 +889,10 @@ def create_learner(

mesh = pxla.thread_resources.env.physical_mesh
if ckpt_dir:
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=1,
checkpointing_options = checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
),
)
else:
checkpointing_options = None
Expand Down
19 changes: 13 additions & 6 deletions tests/rl/grpo/grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp
from tunix.perf import trace as trace_lib
from tunix.perf.experimental import tracer as perf_tracer_v2
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo import grpo_learner as grpo_lib
from tunix.rl.queue import data_queue as queue_lib
from tunix.rl.rollout import base_rollout
from tunix.sft import checkpoint_options
from tunix.sft import profiler
from tunix.tests import test_common as tc
from typing_extensions import override
Expand Down Expand Up @@ -718,9 +719,13 @@ def test_resume_training(self):
eval_every_n_steps=10,
max_steps=10,
checkpoint_root_directory=temp_path,
checkpointing_options=ocp.CheckpointManagerOptions(
save_interval_steps=1,
max_to_keep=10,
checkpointing_options=checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(1)
),
preservation_policy=ocp.training.preservation_policies.LatestN(
10
),
),
),
rollout_config=base_rollout.RolloutConfig(
Expand Down Expand Up @@ -1050,8 +1055,10 @@ def create_learner(
train_micro_batch_size=mini_batch_size,
rollout_micro_batch_size=mini_batch_size,
compute_logps_micro_batch_size=mini_batch_size,
checkpointing_options=ocp.CheckpointManagerOptions(
save_interval_steps=4,
checkpointing_options=checkpoint_options.create_checkpointing_options(
save_decision_policy=(
ocp.training.save_decision_policies.FixedIntervalPolicy(4)
),
),
checkpoint_root_directory=ckpt_dir,
),
Expand Down
87 changes: 77 additions & 10 deletions tests/sft/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import optax
import qwix
from tunix.sft import checkpoint_manager
from tunix.sft import checkpoint_options

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

Expand Down Expand Up @@ -110,20 +111,25 @@ def test_empty_root_directory(self):
def test_checkpoint_manager_options_none_sets_default(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=None)
self.assertIsNotNone(cp_manager._checkpoint_manager)
self.assertIsNotNone(cp_manager._checkpointer)
self.assertEqual(
cp_manager._checkpoint_manager._options, # pytype: disable=attribute-error
checkpoint_manager._DEFAULT_CHECKPOINTING_OPTIONS,
cp_manager._options,
checkpoint_options.DEFAULT_CHECKPOINTING_OPTIONS,
)

def test_context_property(self):
  """Checks that a manager constructed from only a path has a non-None `context`."""
  manager = checkpoint_manager.CheckpointManager(
      f'{self.temp_path}/{self.id()}'
  )
  self.assertIsNotNone(manager.context)

def test_save(self):
cp_path = f'{self.temp_path}/{self.id()}'
cp_manager = checkpoint_manager.CheckpointManager(cp_path)
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

# Save the model state.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error
self.assertEqual(cp_manager.latest_step(), 1)

cp_manager.close()
Expand All @@ -139,7 +145,7 @@ def test_restore(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand All @@ -162,7 +168,7 @@ def test_restore_different_sharding(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, unsharded_model))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error

# Restore the model without shardings.
self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {}))
Expand Down Expand Up @@ -211,7 +217,7 @@ def test_restore_with_lora(self):

# Save the model params.
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
cp_manager._checkpointer.wait() # pytype: disable=attribute-error

# Change the model state.
changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
Expand Down Expand Up @@ -241,7 +247,7 @@ def test_save_and_restore_with_custom_metadata(self):
model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
ckpt_manager._checkpointer.wait() # pytype: disable=attribute-error
restored_step, restored_metadata = ckpt_manager.maybe_restore(model)
self.assertEqual(restored_step, 1)
self.assertEqual(restored_metadata, custom_metadata)
Expand All @@ -257,7 +263,7 @@ def test_save_and_restore_with_optimizer_state(self):
)
custom_metadata = {'foo': 1, 'bar': 2}
ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
ckpt_manager._checkpointer.wait() # pytype: disable=attribute-error

new_optimizer = nnx.Optimizer(
model,
Expand All @@ -281,6 +287,67 @@ def test_save_and_restore_with_optimizer_state(self):
new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3
)

def test_save_and_restore_with_forced_single_device_sharding(self):
  """Restores optimizer state into a target whose learning rate sits on one device.

  Saves a model and optimizer at step 1, builds a second optimizer, moves its
  `learning_rate` hyperparameter onto `jax.devices()[0]`, and restores into it.
  After the restore, every optimizer-state leaf that has a `sharding`
  attribute is expected to use `NamedSharding`, and leaves whose path contains
  'hyperparams' are expected to have an empty `PartitionSpec`. All mismatches
  are collected and reported in a single failure.
  """
  cp_path = f'{self.temp_path}/{self.id()}'
  ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
  model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
  optimizer = nnx.Optimizer(
      model,
      optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
      wrt=nnx.Param,
  )
  custom_metadata = {'foo': 1, 'bar': 2}
  # Save at step 1, then call wait() on the underlying checkpointer before
  # restoring.
  ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
  ckpt_manager._checkpointer.wait()  # pytype: disable=attribute-error

  # Restore target: a second optimizer built with a different learning rate.
  new_optimizer = nnx.Optimizer(
      model,
      optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5),
      wrt=nnx.Param,
  )

  # Move the learning-rate hyperparameter onto the first device only.
  new_optimizer.opt_state.hyperparams['learning_rate'].value = jax.device_put(
      new_optimizer.opt_state.hyperparams['learning_rate'].value,
      jax.devices()[0],
  )

  # Precondition: the value now reports a SingleDeviceSharding.
  self.assertIsInstance(
      new_optimizer.opt_state.hyperparams['learning_rate'].value.sharding,
      jax.sharding.SingleDeviceSharding,
  )

  restored_step, _ = ckpt_manager.maybe_restore(
      model, new_optimizer
  )
  self.assertEqual(restored_step, 1)

  # Failures are accumulated here instead of raised per leaf, so that every
  # mismatching leaf shows up in the final message.
  errors = []
  def assert_named_sharding(path, x):
    # NOTE(review): the nesting of the 'hyperparams' check under the
    # `hasattr` branch is assumed — confirm.
    if hasattr(x, 'sharding'):
      try:
        self.assertIsInstance(
            x.sharding,
            jax.sharding.NamedSharding,
            f'Variable at {path} is not NamedSharding',
        )
      except AssertionError as e:
        errors.append(str(e))
        # Skip the spec check below for a leaf that already failed the
        # NamedSharding check.
        return

      path_str = str(path)
      if 'hyperparams' in path_str:
        try:
          self.assertEqual(x.sharding.spec, jax.sharding.PartitionSpec())
        except AssertionError as e:
          errors.append(str(e))

  # Visit every leaf of the optimizer state with its path; the callback is
  # used only for its side effect on `errors`.
  jax.tree.map_with_path(
      assert_named_sharding,
      nnx.state(new_optimizer, nnx.optimizer.OptState),
  )
  if errors:
    self.fail('Found sharding mismatches:\n' + '\n'.join(errors))

def test_restore_without_optimizer(self):
cp_path = f'{self.temp_path}/{self.id()}'
ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
Expand All @@ -291,7 +358,7 @@ def test_restore_without_optimizer(self):
wrt=nnx.Param,
)
ckpt_manager.save(1, model, optimizer)
ckpt_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error
ckpt_manager._checkpointer.wait() # pytype: disable=attribute-error
ckpt_manager.maybe_restore(model)

@parameterized.parameters(['test_data/checkpoints'])
Expand Down
Loading
Loading