Description
RLCluster._init_cluster() (rl_cluster.py line 392) unconditionally imports vllm_rollout regardless of the rollout_engine setting. When rollout_engine='vanilla' is set in the config, vLLM should not be required as a dependency.
Error
File "tunix/rl/rl_cluster.py", line 392, in _init_cluster
from tunix.rl.rollout import vllm_rollout
...
ModuleNotFoundError: No module named 'vllm'
Even after installing vLLM, there's a version mismatch:
from vllm.platforms import current_platform
ImportError: cannot import name 'current_platform' from 'vllm.platforms'
Environment
- Tunix 0.1.7 (pip install google-tunix)
- MaxText HEAD (installed from GitHub)
- JAX 0.9.2
- v5litepod-32 TPU (multi-host, 8 VMs × 4 chips)
Additional Context
Running MaxText's train_rl.py on direct TPU VMs (not Pathways/GKE). The vanilla rollout engine should work without vLLM for single-turn GRPO.
Additionally, MaxText HEAD's train_rl.py passes several rollout_vllm_* kwargs to RolloutConfig that Tunix 0.1.7 doesn't support:
rollout_vllm_swap_space_size_gb
rollout_vllm_hf_config_path
rollout_vllm_additional_config
rollout_vllm_init_with_random_weights
rollout_vllm_enable_dp_attention
rollout_vllm_max_num_batched_tokens
rollout_vllm_max_num_seqs
rollout_vllm_tpu_backend_type
This suggests MaxText HEAD and Tunix 0.1.7 are out of sync.
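As a stopgap until the two repos are re-synced, one can drop kwargs the installed RolloutConfig doesn't accept before constructing it. This is a minimal sketch, not Tunix code: `filter_supported_kwargs` and `FakeRolloutConfig` are hypothetical names, and the real RolloutConfig may not be introspectable this way.

```python
import inspect
from dataclasses import dataclass


def filter_supported_kwargs(cls, kwargs):
    # Keep only kwargs that the installed class's constructor accepts,
    # silently dropping newer fields (e.g. rollout_vllm_* added at HEAD).
    accepted = set(inspect.signature(cls).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}


# Stand-in for Tunix 0.1.7's RolloutConfig, for illustration only.
@dataclass
class FakeRolloutConfig:
    max_tokens: int = 64


cfg_kwargs = filter_supported_kwargs(
    FakeRolloutConfig,
    {"max_tokens": 8, "rollout_vllm_max_num_seqs": 2},
)
# cfg_kwargs now contains only {"max_tokens": 8}
```

Silently dropping kwargs hides real config, so logging the discarded keys would be safer in practice.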
Suggested Fix
Make the vllm_rollout import conditional on rollout_engine != 'vanilla' in _init_cluster().
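A guarded import along these lines would avoid the hard dependency. This is only a sketch of the idea, not the actual _init_cluster() body: `init_rollout` is a hypothetical name, and the vanilla branch is stubbed out here.

```python
import importlib
import importlib.util


def init_rollout(rollout_engine: str):
    # Only touch vLLM when the engine actually needs it, so
    # rollout_engine='vanilla' works without vllm installed.
    if rollout_engine == "vanilla":
        return "vanilla"  # real code would build the vanilla rollout here

    # Fail with an actionable message instead of a bare ModuleNotFoundError.
    if importlib.util.find_spec("vllm") is None:
        raise ImportError(
            f"rollout_engine={rollout_engine!r} requires vLLM; "
            "install vllm or set rollout_engine='vanilla'."
        )
    return importlib.import_module("tunix.rl.rollout.vllm_rollout")
```

This would not fix the vllm.platforms version mismatch, but it makes the vanilla path usable and turns the vLLM requirement into an explicit, per-engine error.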