Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ jobs:
apt-get update; apt-get install -y less

cd tunix && python tests/generate/sglang_jax_sampler_test.py && python tests/generate/sglang_jax_lora_test.py
python tests/rl/rl_cluster_test.py
- name: Run tunix SFT integration tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand Down
328 changes: 328 additions & 0 deletions tests/rl/rl_cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import functools
import os
import unittest
from unittest import mock

from absl.testing import absltest
Expand All @@ -32,6 +34,8 @@
from tunix.rl.rollout import base_rollout
from tunix.tests import test_common as tc

# Some tests relying on SGLang and vLLM cannot run in run_prod environment.
is_run_prod = os.environ.get('GITHUB_JOB') == 'run_prod'

PreTrainedTokenizerBase = tokenization_utils_base.PreTrainedTokenizerBase
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
Expand Down Expand Up @@ -322,6 +326,320 @@ def test_generate_with_chat_template(self): # pylint: disable=g-doc-args
called_prompts = rl_cluster.rollout.generate.call_args[0][0]
self.assertEqual(called_prompts, ['formatted prompt'])

def _create_test_rl_cluster(
    self,
    rollout_engine: str,
    rollout_config: base_rollout.RolloutConfig,
) -> rl_cluster_lib.RLCluster:
  """Builds an RLCluster over split device meshes for engine-init tests.

  The available devices are split in half: actor/reference share the first
  half, rollout gets the second half. A toy transformer and mock vocab are
  used so the cluster can be constructed without real checkpoints.
  """
  half = self.device_count // 2
  devices = jax.devices()
  # Actor and reference share one mesh; rollout runs on the remaining devices.
  train_mesh = Mesh(
      np.array(devices[:half]).reshape(half, 1),
      ('fsdp', 'tp'),
  )
  sampler_mesh = Mesh(
      np.array(devices[half:]).reshape(1, half),
      ('fsdp', 'tp'),
  )
  config = rl_cluster_lib.ClusterConfig(
      role_to_mesh={
          rl_cluster_lib.Role.ACTOR: train_mesh,
          rl_cluster_lib.Role.REFERENCE: train_mesh,
          rl_cluster_lib.Role.ROLLOUT: sampler_mesh,
      },
      rollout_engine=rollout_engine,
      offload_to_cpu=False,
      training_config=rl_cluster_lib.RLTrainingConfig(
          actor_optimizer=optax.sgd(1e-3),
          eval_every_n_steps=1,
          max_steps=10,
          gradient_accumulation_steps=None,
      ),
      rollout_config=rollout_config,
  )
  vocab = tc.MockVocab()
  toy_model = tc.ToyTransformer(
      config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()), rngs=nnx.Rngs(0)
  )
  return rl_cluster_lib.RLCluster(
      actor=toy_model, tokenizer=vocab, cluster_config=config
  )

def test_init_cluster_invalid_engine_string(self):
  """An engine name outside the supported set raises ValueError."""
  expected_msg = '`cluster_config.rollout_engine` should be one of'
  with self.assertRaisesRegex(ValueError, expected_msg):
    self._create_test_rl_cluster(
        'invalid_engine', base_rollout.RolloutConfig()
    )

@parameterized.parameters('vanilla', 'vllm', 'sglang_jax')
def test_init_rollout_engine_missing_config_raises_error(self, engine):
  """Passing rollout_config=None is rejected for every built-in engine."""
  raises_ctx = self.assertRaisesRegex(
      ValueError, '`cluster_config.rollout_config` cannot be None.'
  )
  with raises_ctx:
    self._create_test_rl_cluster(engine, None)

@parameterized.parameters('vanilla', 'vllm', 'sglang_jax')
def test_init_rollout_engine_empty_dict_config_raises_error(self, engine):
  """A dict rollout_config without a TRAIN entry is rejected."""
  raises_ctx = self.assertRaisesRegex(
      ValueError,
      'Rollout config is a dict but missing a train config.',
  )
  with raises_ctx:
    self._create_test_rl_cluster(engine, {})

@parameterized.named_parameters(
    dict(
        testcase_name='single_config',
        rollout_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10,
            kv_cache_size=1024,
            data_type=jnp.bfloat16,
        ),
        expected_cache_size=1024,
    ),
    dict(
        testcase_name='dict_config',
        rollout_config={
            rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
                max_tokens_to_generate=10,
                kv_cache_size=1024,
                data_type=jnp.bfloat16,
            ),
            rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
                max_tokens_to_generate=10,
                kv_cache_size=2048,
                data_type=jnp.bfloat16,
            ),
        },
        expected_cache_size=2048,
    ),
)
@mock.patch.object(
    rl_cluster_lib.vanilla_rollout, 'VanillaRollout', autospec=True
)
def test_init_vanilla_rollout_engine(
    self, vanilla_mock, rollout_config, expected_cache_size
):
  """VanillaRollout receives a CacheConfig with the expected cache size."""
  cluster = self._create_test_rl_cluster('vanilla', rollout_config)

  vanilla_mock.assert_called_once()
  self.assertEqual(cluster.rollout, vanilla_mock.return_value)
  passed_kwargs = vanilla_mock.call_args.kwargs
  cache_arg = passed_kwargs['cache_config_or_size']
  self.assertIsInstance(cache_arg, base_rollout.CacheConfig)
  self.assertEqual(cache_arg.cache_size, expected_cache_size)

def test_init_vanilla_rollout_engine_missing_model_config(self):
  """Vanilla rollout requires the actor model to expose a `config` attr."""

  class DummyModel(nnx.Module):
    # Deliberately lacks the `config` attribute the vanilla engine expects.

    def __init__(self):
      self.w = nnx.Param(jnp.zeros((1,)))

  half = self.device_count // 2
  shared_mesh = Mesh(
      np.array(jax.devices()[:half]).reshape(half, 1),
      ('fsdp', 'tp'),
  )
  cluster_config = rl_cluster_lib.ClusterConfig(
      role_to_mesh={
          rl_cluster_lib.Role.ACTOR: shared_mesh,
          rl_cluster_lib.Role.REFERENCE: shared_mesh,
          rl_cluster_lib.Role.ROLLOUT: shared_mesh,
      },
      rollout_engine='vanilla',
      offload_to_cpu=False,
      training_config=rl_cluster_lib.RLTrainingConfig(
          actor_optimizer=optax.sgd(1e-3),
          eval_every_n_steps=1,
      ),
      rollout_config=base_rollout.RolloutConfig(),
  )

  with self.assertRaisesRegex(
      ValueError, '`self.rollout_actor` must have a config attribute.'
  ):
    rl_cluster_lib.RLCluster(
        actor=DummyModel(),
        tokenizer=tc.MockVocab(),
        cluster_config=cluster_config,
    )

@parameterized.named_parameters(
    dict(
        testcase_name='single_config',
        rollout_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10, kv_cache_size=1024
        ),
        expected_train_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10, kv_cache_size=1024
        ),
    ),
    dict(
        testcase_name='dict_config',
        rollout_config={
            rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
                max_tokens_to_generate=10, kv_cache_size=1024
            ),
            rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
                max_tokens_to_generate=20, kv_cache_size=2048
            ),
        },
        expected_train_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10, kv_cache_size=1024
        ),
    ),
)
@mock.patch.object(rl_cluster_lib.mock_rollout, 'MockRollout', autospec=True)
def test_init_mock_rollout_engine(
    self, rollout_mock, rollout_config, expected_train_config
):
  """MockRollout receives the TRAIN rollout config."""
  cluster = self._create_test_rl_cluster('mock', rollout_config)

  rollout_mock.assert_called_once()
  self.assertEqual(cluster.rollout, rollout_mock.return_value)
  passed_kwargs = rollout_mock.call_args.kwargs
  self.assertEqual(passed_kwargs['rollout_config'], expected_train_config)

@parameterized.named_parameters(
    dict(
        testcase_name='single_config',
        rollout_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10,
            kv_cache_size=1024,
            rollout_vllm_model_version='dummy_version',
        ),
        expected_train_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10,
            kv_cache_size=1024,
            rollout_vllm_model_version='dummy_version',
        ),
        expected_cache_size=1024,
    ),
    dict(
        testcase_name='dict_config',
        rollout_config={
            rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
                max_tokens_to_generate=10,
                kv_cache_size=1024,
                rollout_vllm_model_version='dummy_version',
            ),
            rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
                max_tokens_to_generate=20,
                kv_cache_size=2048,
                rollout_vllm_model_version='dummy_version',
            ),
        },
        expected_train_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10,
            kv_cache_size=1024,
            rollout_vllm_model_version='dummy_version',
        ),
        expected_cache_size=2048,
    ),
)
@unittest.skipIf(is_run_prod, 'Skipping in run_prod')
def test_init_vllm_rollout_engine(
    self,
    rollout_config,
    expected_train_config,
    expected_cache_size,
):
  """VllmRollout receives the TRAIN config, cache size and a mesh."""
  # Internal placeholder for vllm rollout worker stub, don't change this line.
  from tunix.rl.rollout import vllm_rollout

  vllm_patch = mock.patch.object(vllm_rollout, 'VllmRollout', autospec=True)
  with vllm_patch as vllm_mock:
    cluster = self._create_test_rl_cluster('vllm', rollout_config)

  vllm_mock.assert_called_once()
  self.assertEqual(cluster.rollout, vllm_mock.return_value)
  passed_kwargs = vllm_mock.call_args.kwargs
  self.assertEqual(passed_kwargs['rollout_config'], expected_train_config)
  self.assertEqual(
      passed_kwargs['cache_config_or_size'], expected_cache_size
  )
  self.assertIn('mesh', passed_kwargs)

@unittest.skipIf(is_run_prod, 'Skipping in run_prod')
def test_init_vllm_rollout_engine_missing_version_raises(self):
  """vLLM engine init fails fast when no model version/path is given."""
  config_without_version = base_rollout.RolloutConfig(
      rollout_vllm_model_version=None,
  )
  with self.assertRaisesRegex(
      ValueError, 'Rollout vllm model version or path is missing!'
  ):
    self._create_test_rl_cluster('vllm', config_without_version)

@parameterized.named_parameters(
    dict(
        testcase_name='single_config',
        rollout_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10, kv_cache_size=1024
        ),
        expected_train_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10, kv_cache_size=1024
        ),
    ),
    dict(
        testcase_name='dict_config',
        rollout_config={
            rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
                max_tokens_to_generate=10, kv_cache_size=1024
            ),
            rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
                max_tokens_to_generate=20, kv_cache_size=2048
            ),
        },
        expected_train_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=10, kv_cache_size=1024
        ),
    ),
)
@unittest.skipIf(is_run_prod, 'Skipping in run_prod')
def test_init_sglang_jax_rollout_engine(
    self, rollout_config, expected_train_config
):
  """SglangJaxRollout receives the TRAIN config and a mesh."""
  # Internal placeholder for sglang_jax rollout worker stub, don't change this line.
  from tunix.rl.rollout import sglang_jax_rollout

  sglang_patch = mock.patch.object(
      sglang_jax_rollout, 'SglangJaxRollout', autospec=True
  )
  with sglang_patch as sglang_mock:
    cluster = self._create_test_rl_cluster('sglang_jax', rollout_config)

  sglang_mock.assert_called_once()
  self.assertEqual(cluster.rollout, sglang_mock.return_value)
  passed_kwargs = sglang_mock.call_args.kwargs
  self.assertEqual(passed_kwargs['rollout_config'], expected_train_config)
  self.assertIn('mesh', passed_kwargs)

@unittest.skipIf(is_run_prod, 'Skipping in run_prod')
@mock.patch.object(rl_cluster_lib.sft_utils, 'is_lora_enabled', autospec=True)
def test_init_sglang_jax_rollout_engine_lora_error(self, mock_is_lora):
  """A LoRA-enabled model without the static-LoRA rollout flag is rejected."""
  mock_is_lora.return_value = True
  lora_off_config = base_rollout.RolloutConfig(
      rollout_sglang_jax_enable_static_lora=False
  )

  raises_ctx = self.assertRaisesRegex(
      ValueError, 'Rollout sglang jax lora config is missing'
  )
  with raises_ctx:
    self._create_test_rl_cluster('sglang_jax', lora_off_config)

def test_init_cluster_unsupported_engine_type(self):
  """An unsupported engine object raises NotImplementedError."""

  class InvalidEngine:
    pass

  with self.assertRaisesRegex(
      NotImplementedError, 'Rollout engine .* not supported'
  ):
    self._create_test_rl_cluster(InvalidEngine, base_rollout.RolloutConfig())

def test_user_defined_rollout_engine_class(self):
class CustomRolloutEngine(base_rollout.BaseRollout):

Expand Down Expand Up @@ -363,6 +681,13 @@ def model(self) -> nnx.Module:
def update_params(self, params, filter_types):
pass

@property
def mesh(self):
return Mesh(
np.array(jax.devices()[:1]).reshape(1, 1),
('fsdp', 'tp'),
)

split_index = self.device_count // 2

actor_mesh = Mesh(
Expand Down Expand Up @@ -443,6 +768,9 @@ def create_cluster_config(rollout_engine):
self.assertIsInstance(rl_cluster.rollout, CustomRolloutEngine)
self.assertEqual(rl_cluster.rollout.my_arg, 0)
self.assertEqual(rl_cluster.rollout.config, cluster_config.rollout_config)
self.assertEqual(
rl_cluster.r2m[rl_cluster_lib.Role.ROLLOUT], rl_cluster.rollout.mesh
)

@parameterized.named_parameters(
dict(
Expand Down
Loading
Loading