
[sglang] feat: Add SGLang async multi-turn rollout with tool support #1037


Conversation

@SwordFaith (Collaborator) commented Apr 11, 2025

A redesigned version of #917

Current Status

Develop log & Tracker: zhaochenyang20/Awesome-ML-SYS-Tutorial#113

What Has Been Done

  • Async Rollout Refactoring: Integrates with the tool server to coordinate tool calls during generation, leverages request IDs for state and progress tracking, and supports async multi-turn conversations in agentic RL training (with tool support).
  • Async Request Management: Encapsulates rollout requests into a unified structure, enabling efficient tracking and handling of concurrent multi-turn dialogues with ChatML-style messages.
  • Extensible Tools: A modular design for adapting tools in the OpenAIFunctionTool format, which is supported by both SGLang and vLLM. Each tool creates a separate instance, executes on tool call, calculates a score from the tool environment state, and releases its resources.
  • Multi-turn support has been implemented for the GSM8K task (a new version is in progress). However, training has not yet converged, and we hope the community will join us in investigating the issue.
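The async rollout and request-management design described above can be sketched roughly as follows. This is a minimal illustration, not the actual verl/SGLang API: `AsyncRolloutRequest`, `call_engine`, and `execute_tool` are hypothetical stand-ins.

```python
import asyncio
import uuid
from dataclasses import dataclass, field

@dataclass
class AsyncRolloutRequest:
    """Unified wrapper for one multi-turn dialogue (hypothetical sketch)."""
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    messages: list = field(default_factory=list)  # ChatML-style messages
    done: bool = False

async def call_engine(messages):
    # Stand-in for an async generation call to the inference engine.
    return {"role": "assistant", "content": "#### 18", "tool_calls": []}

async def execute_tool(tool_call):
    # Stand-in for dispatching a tool call and returning its observation.
    return {"role": "tool", "content": "reward=1.0"}

async def rollout(req: AsyncRolloutRequest, max_turns: int = 4):
    # Each request is tracked by its request_id, so many multi-turn
    # dialogues can be interleaved concurrently on the same engine.
    for _ in range(max_turns):
        reply = await call_engine(req.messages)
        req.messages.append(reply)
        if not reply["tool_calls"]:
            req.done = True  # model answered without requesting a tool
            break
        for tc in reply["tool_calls"]:
            req.messages.append(await execute_tool(tc))
    return req

req = asyncio.run(
    rollout(AsyncRolloutRequest(messages=[{"role": "user", "content": "2+2?"}]))
)
```

In the real implementation the engine call and tool execution are both awaited per request, so a batch of rollouts is simply `asyncio.gather` over many such loops.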

What Is WIP

  • Merge the loss mask from the previous version into the training process
  • Add a more user-friendly tool config and e2e tests for GSM8K training with tools
  • Validate the multi-turn feature in open-source sandbox environments
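For reference, the OpenAIFunctionTool format mentioned above follows the OpenAI function-calling schema. As a concrete instance, the GSM8K reward tool declares itself like this (copied from the system prompt in the training log quoted later in this thread):

```python
import json

# The GSM8K reward tool in the OpenAI function-calling (OpenAIFunctionTool)
# schema, as it appears in the <tools> block of the system prompt.
calc_gsm8k_reward = {
    "type": "function",
    "function": {
        "name": "calc_gsm8k_reward",
        "description": (
            "A tool for calculating the reward of gsm8k. "
            "(1.0 if your answer is correct, 0.0 if your answer is incorrect)"
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "answer": {
                    "type": "string",
                    "description": "The model's answer to the GSM8K math problem, must be a digits",
                    "enum": None,
                }
            },
            "required": ["answer"],
        },
        "strict": False,
    },
}

# Serialize as it would be embedded in the prompt's <tools> section.
tool_json = json.dumps(calc_gsm8k_reward)
```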

Key features to be introduced in future versions

  • Integrate a Ray-based agent trainer to enable explicit separation of the rollout and training pipelines, with support for partial rollout handling and fine-grained request state management.
  • Extend the framework to support simulated user interactions (e.g., roleplay, interactive feedback) and more complex environment-in-the-loop RL tasks.

Future Plan
Discussion Thread: zhaochenyang20/Awesome-ML-SYS-Tutorial#74 (comment)
The RFC doc (https://github.com/SwordFaith/verl-sglang-dev-log/blob/main/rlhf/verl/multi-turn/veRL-multiturn-rollout-RFC.md) will be updated soon.


@CLAassistant

CLAassistant commented Apr 11, 2025

CLA assistant check
All committers have signed the CLA.

@eric-haibin-lin (Collaborator) left a comment

Just waiting for CI to complete

@zhaochenyang20 (Collaborator)

Great!

@eric-haibin-lin eric-haibin-lin merged commit e0d035c into volcengine:main Apr 29, 2025
24 of 25 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Apr 30, 2025
…olcengine#1037)

(The commit message repeats the PR description above.)
## Contributors & Acknowledgement

- Xiang Long [mid.of.change@gmail.com](mailto:mid.of.change@gmail.com)
@SwordFaith (Design RFC & core-dev of refactor part)
- Yuzhen Zhou [zyzshishui@gmail.com](mailto:zyzshishui@gmail.com)
@zyzshishui (Core-dev)
- Chenyang Zhao [zhaochen20@outlook.com](mailto:zhaochen20@outlook.com)
@zhaochenyang20 (PM)
- Guanhua Wang @WANG-GH 
- Junrong Lin @ocss884 (verl-sglang support)
- Hanchen Zhang
[zhanghanchen77@gmail.com](mailto:zhanghanchen77@gmail.com)
- Haoran Wang [ubecwang@gmail.com](mailto:ubecwang@gmail.com)
- Rui Lu [learningrate1@gmail.com](mailto:learningrate1@gmail.com)
- Yujiang Li [liyujiang2020@gmail.com](mailto:liyujiang2020@gmail.com)
- Jiajun Li [guapisolo@gmail.com](mailto:guapisolo@gmail.com)
- Jin Pan [jpan236@wisc.edu](mailto:jpan236@wisc.edu)
- Zhi Zheng [zhengzhi@modelbest.cn](mailto:zhengzhi@modelbest.cn)
@zh-zheng

---------

Co-authored-by: zyzshishui <492129152@qq.com>
Co-authored-by: guanhua <281484683@qq.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
Co-authored-by: ocss884 <ocss.lin@gmail.com>
Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
Co-authored-by: HL <linhaibin.eric@gmail.com>
@ChrisRBXiong

Hello, I used verl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh for training, and grad_norm became NaN from the second step onward.

The specific log is as follows:

Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1871.16 examples/s]
(TaskRunner pid=362057) filter dataset len: 7473
(TaskRunner pid=362057) dataset len: 1319
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.28 examples/s]
(TaskRunner pid=362057) filter dataset len: 1319
(TaskRunner pid=362057) Size of train dataloader: 29, Size of val dataloader: 1
(TaskRunner pid=362057) Total training steps: 4350
(TaskRunner pid=362057) colocated worker base class <class 'verl.single_controller.base.worker.Worker'>
(TaskRunner pid=362057) DeprecationWarning: `ray.state.available_resources_per_node` is a private attribute and access will be removed in a future Ray version.
(TaskRunner pid=362057) WARNING:2025-05-01 20:50:26,029:Waiting for register center actor w58LT1_register_center to be ready. Elapsed time: 0 seconds out of 300 seconds.
(WorkerDict pid=50072) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50072) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=50072) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 85.04it/s]
(WorkerDict pid=50072) [rank3]:[W501 20:50:39.657321419 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 3]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
(WorkerDict pid=50070) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 79.79it/s]
(WorkerDict pid=49813) Model config after override: Qwen2Config {
(WorkerDict pid=49813)   "architectures": [
(WorkerDict pid=49813)     "Qwen2ForCausalLM"
(WorkerDict pid=49813)   ],
(WorkerDict pid=49813)   "attention_dropout": 0.0,
(WorkerDict pid=49813)   "eos_token_id": 151645,
(WorkerDict pid=49813)   "hidden_act": "silu",
(WorkerDict pid=49813)   "hidden_size": 3584,
(WorkerDict pid=49813)   "initializer_range": 0.02,
(WorkerDict pid=49813)   "intermediate_size": 18944,
(WorkerDict pid=49813)   "max_position_embeddings": 32768,
(WorkerDict pid=49813)   "max_window_layers": 28,
(WorkerDict pid=49813)   "model_type": "qwen2",
(WorkerDict pid=49813)   "num_attention_heads": 28,
(WorkerDict pid=49813)   "num_hidden_layers": 28,
(WorkerDict pid=49813)   "num_key_value_heads": 4,
(WorkerDict pid=49813)   "pad_token_id": 151643,
(WorkerDict pid=49813)   "rms_norm_eps": 1e-06,
(WorkerDict pid=49813)   "rope_scaling": null,
(WorkerDict pid=49813)   "rope_theta": 1000000.0,
(WorkerDict pid=49813)   "sliding_window": 131072,
(WorkerDict pid=49813)   "tie_word_embeddings": false,
(WorkerDict pid=49813)   "torch_dtype": "bfloat16",
(WorkerDict pid=49813)   "transformers_version": "4.51.1",
(WorkerDict pid=49813)   "use_cache": true,
(WorkerDict pid=49813)   "use_sliding_window": false,
(WorkerDict pid=49813)   "vocab_size": 152064
(WorkerDict pid=49813) }
(WorkerDict pid=49813) 
(WorkerDict pid=49813) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=49813) wrap_policy: functools.partial(<function _or_policy at 0x7ef5ed872320>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7ef5ed872200>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})])
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(WorkerDict pid=49813) Actor use_remove_padding=True
(WorkerDict pid=50070) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=49813) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`. [repeated 7x across cluster]
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.62s/it]
(WorkerDict pid=50070) Total steps: 4350, num_warmup_steps: 0
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
(WorkerDict pid=50070) No module named 'vllm._version'
(WorkerDict pid=50070)   from vllm.version import __version__ as VLLM_VERSION
(WorkerDict pid=49813) Before building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 20.63/95.00
(WorkerDict pid=50070) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) Total steps: 4350, num_warmup_steps: 0 [repeated 7x across cluster]
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=50070)   warnings.warn(
(WorkerDict pid=49813) 
Capturing batches (avail_mem=35.53 GB):   0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=32.35 GB):  94%|█████████▍| 33/35 [00:15<00:00,  2.33it/s]
(WorkerDict pid=49813) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) After building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) After building sharding manager, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=49813)   warnings.warn(
(TaskRunner pid=362057) Using LocalLogger is deprecated. The constructor API will change 
(TaskRunner pid=362057) Checkpoint tracker file does not exist: %s /tmp/ray/session_2025-05-01_20-49-28_274415_338741/runtime_resources/working_dir_files/_ray_pkg_4ce66e647b0fe766/checkpoints/test/latest_checkpointed_iteration.txt
(TaskRunner pid=362057) Training from scratch
(TaskRunner pid=362057) test_gen_batch meta info: {'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': False, 'validate': True}
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/sglang/srt/utils.py:888: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
(WorkerDict pid=49813)   tensor_data = torch.ByteTensor(
(WorkerDict pid=50076) 
Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.44it/s]
(TaskRunner pid=362057) validation generation end
(TaskRunner pid=362057) [prompt] system
(TaskRunner pid=362057) 
(TaskRunner pid=362057)                             You are a math expert. You are given a question and you need to solve it step by step.  
(TaskRunner pid=362057)                             `calc_gsm8k_reward` is a tool for calculating the reward of gsm8k. You should use this 
(TaskRunner pid=362057)                             tool to calculate the reward of your answer(1.0 if your answer is correct, 0.0 if your 
(TaskRunner pid=362057)                             answer is incorrect) before submitting it and refine your answer if necessary. Put your 
(TaskRunner pid=362057)                             final answer in the format of `#### <answer>`.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) # Tools
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You may call one or more functions to assist with the user query.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You are provided with function signatures within <tools></tools> XML tags:
(TaskRunner pid=362057) <tools>
(TaskRunner pid=362057) {"type": "function", "function": {"name": "calc_gsm8k_reward", "description": "A tool for calculating the reward of gsm8k. (1.0 if your answer is correct, 0.0 if your answer is incorrect)", "parameters": {"type": "object", "properties": {"answer": {"type": "string", "description": "The model's answer to the GSM8K math problem, must be a digits", "enum": null}}, "required": ["answer"]}, "strict": false}}
(TaskRunner pid=362057) </tools>
(TaskRunner pid=362057) 
(TaskRunner pid=362057) For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": <function-name>, "arguments": <args-json-object>}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) user
(TaskRunner pid=362057) Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 
(TaskRunner pid=362057)         You must use the `calc_gsm8k_reward` tool to calculate the reward 
(TaskRunner pid=362057)         of your answer(1.0 if your answer is correct, 0.0 if your answer is incorrect) 
(TaskRunner pid=362057)         before submitting it at least once and refine your answer if necessary. 
(TaskRunner pid=362057)         Put your final answer in the format of `#### <answer>`.
(TaskRunner pid=362057)     
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) 
(TaskRunner pid=362057) [response] To find out how much Janet makes every day at the farmers' market, we need to follow these steps:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Calculate the total number of eggs laid by the ducks per day.
(TaskRunner pid=362057) 2. Subtract the number of eggs Janet eats for breakfast and the number of eggs she uses for baking muffins.
(TaskRunner pid=362057) 3. Multiply the remaining number of eggs by the price per egg to find out how much she makes at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Let's do the calculations:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Total number of eggs laid per day: 16
(TaskRunner pid=362057) 2. Number of eggs Janet eats for breakfast: 3
(TaskRunner pid=362057) 3. Number of eggs Janet uses for baking muffins: 4
(TaskRunner pid=362057) 4. Remaining number of eggs: 16 - 3 - 4 = 9
(TaskRunner pid=362057) 5. Price per egg: $2
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's calculate the total amount Janet makes at the farmers' market:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) \[ \text{Total amount} = \text{Remaining number of eggs} \times \text{Price per egg} \]
(TaskRunner pid=362057) \[ \text{Total amount} = 9 \times 2 = 18 \]
(TaskRunner pid=362057) 
(TaskRunner pid=362057) So, Janet makes $18 every day at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's use the `calc_gsm8k_reward` tool to check the correctness of our answer.
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": "calc_gsm8k_reward", "arguments": "{\"answer\": \"18\"}"}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) tool
(TaskRunner pid=362057) Current parsed answer='18' reward=1.0
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) #### 18
(TaskRunner pid=362057) [ground_truth] 18
(TaskRunner pid=362057) [score] 1.0
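For readers following the transcript above: the `<tool_call>` block and the `#### <answer>` pattern can be parsed with a few lines of Python. This is an illustrative sketch only, not verl's actual implementation; both function names are hypothetical.

```python
import json
import re

def parse_tool_call(text: str):
    """Extract a tool call of the form seen in the rollout above:
    <tool_call>{"name": ..., "arguments": "<json string>"}</tool_call>
    """
    m = re.search(r"<tool_call>\s*(\{.*\})\s*</tool_call>", text, re.DOTALL)
    if m is None:
        return None
    call = json.loads(m.group(1))
    # "arguments" is itself a JSON-encoded string, so decode it a second time.
    return call["name"], json.loads(call["arguments"])

def parse_gsm8k_answer(response: str):
    """Pull the final answer out of the `#### <answer>` pattern."""
    m = re.search(r"####\s*(-?[\d,.]+)", response)
    return m.group(1).replace(",", "") if m else None
```

Applied to the assistant turn above, `parse_tool_call` yields `("calc_gsm8k_reward", {"answer": "18"})`, and `parse_gsm8k_answer("#### 18")` yields `"18"`, matching the `parsed answer='18' reward=1.0` tool message.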
(TaskRunner pid=362057) 'Initial validation metrics: {}'
(TaskRunner pid=362057) step:0
(TaskRunner pid=362057) 
Training Progress:   0%|          | 0/4350 [00:00<?, ?it/s]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(TaskRunner pid=362057) step:1 - global_seqlen/min:379680.000 - global_seqlen/max:457144.000 - global_seqlen/minmax_diff:77464.000 - global_seqlen/balanced_min:416128.000 - global_seqlen/balanced_max:416129.000 - global_seqlen/mean:416128.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.004 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:0.010 - perf/mfu/actor:0.700 - perf/max_memory_allocated_gb:49.644 - perf/max_memory_reserved_gb:80.434 - perf/cpu_memory_used_gb:108.922 - actor/lr:0.000 - training/global_step:1.000 - training/epoch:0.000 - critic/score/mean:0.905 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.905 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.653 - critic/advantages/min:-1.436 - critic/returns/mean:-0.000 - critic/returns/max:0.653 - critic/returns/min:-1.436 - response_length/mean:364.025 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.026 - prompt_length/mean:448.727 - prompt_length/max:533.000 - prompt_length/min:410.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:118.270 - timing_s/reward:1.181 - timing_s/old_log_prob:52.755 - timing_s/ref:54.759 - timing_s/adv:0.059 - timing_s/update_actor:188.158 - timing_s/step:415.307 - timing_per_token_ms/gen:0.079 - timing_per_token_ms/update_actor:0.057 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.016 - perf/total_num_tokens:3329030.000 - perf/time_per_step:415.307 - perf/throughput:1001.979
(TaskRunner pid=362057) 
Training Progress:   0%|          | 1/4350 [06:57<504:26:22, 417.56s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(TaskRunner pid=362057) step:2 - global_seqlen/min:407422.000 - global_seqlen/max:440440.000 - global_seqlen/minmax_diff:33018.000 - global_seqlen/balanced_min:418815.000 - global_seqlen/balanced_max:418816.000 - global_seqlen/mean:418815.625 - actor/entropy_loss:0.218 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.006 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.703 - perf/max_memory_allocated_gb:54.050 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.974 - actor/lr:0.000 - training/global_step:2.000 - training/epoch:0.000 - critic/score/mean:0.903 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.903 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.001 - critic/advantages/max:1.436 - critic/advantages/min:-1.677 - critic/returns/mean:-0.001 - critic/returns/max:1.436 - critic/returns/min:-1.677 - response_length/mean:368.902 - response_length/max:1024.000 - response_length/min:156.000 - response_length/clip_ratio:0.028 - prompt_length/mean:449.098 - prompt_length/max:559.000 - prompt_length/min:409.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:112.222 - timing_s/reward:1.169 - timing_s/old_log_prob:50.543 - timing_s/ref:49.970 - timing_s/adv:0.057 - timing_s/update_actor:188.153 - timing_s/step:402.215 - timing_per_token_ms/gen:0.074 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3350525.000 - perf/time_per_step:402.215 - perf/throughput:1041.273
(TaskRunner pid=362057) 
Training Progress:   0%|          | 2/4350 [13:40<493:36:46, 408.70s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 7x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:3 - global_seqlen/min:387116.000 - global_seqlen/max:433827.000 - global_seqlen/minmax_diff:46711.000 - global_seqlen/balanced_min:416119.000 - global_seqlen/balanced_max:416120.000 - global_seqlen/mean:416119.125 - actor/entropy_loss:0.212 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.001 - actor/pg_clipfrac:0.000 - actor/ppo_kl:-0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.709 - perf/max_memory_allocated_gb:54.534 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.965 - actor/lr:0.000 - training/global_step:3.000 - training/epoch:0.000 - critic/score/mean:0.890 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.890 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.250 - critic/advantages/min:-3.750 - critic/returns/mean:-0.000 - critic/returns/max:0.250 - critic/returns/min:-3.750 - response_length/mean:364.510 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.020 - prompt_length/mean:448.223 - prompt_length/max:534.000 - prompt_length/min:411.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:125.431 - timing_s/reward:1.166 - timing_s/old_log_prob:49.390 - timing_s/ref:49.245 - timing_s/adv:0.058 - timing_s/update_actor:185.307 - timing_s/step:410.658 - timing_per_token_ms/gen:0.084 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3328953.000 - perf/time_per_step:410.658 - perf/throughput:1013.299
(TaskRunner pid=362057) 
Training Progress:   0%|          | 3/4350 [20:30<494:43:43, 409.71s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 6x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:4 - global_seqlen/min:408386.000 - global_seqlen/max:454859.000 - global_seqlen/minmax_diff:46473.000 - global_seqlen/balanced_min:422598.000 - global_seqlen/balanced_max:422599.000 - global_seqlen/mean:422598.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.003 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.705 - perf/max_memory_allocated_gb:55.359 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:110.119 - actor/lr:0.000 - training/global_step:4.000 - training/epoch:0.000 - critic/score/mean:0.872 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.872 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.016 - critic/advantages/min:-0.465 - critic/returns/mean:-0.000 - critic/returns/max:2.016 - critic/returns/min:-0.465 - response_length/mean:377.724 - response_length/max:1024.000 - response_length/min:95.000 - response_length/clip_ratio:0.029 - prompt_length/mean:447.664 - prompt_length/max:550.000 - prompt_length/min:415.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:133.983 - timing_s/reward:1.183 - timing_s/old_log_prob:50.797 - timing_s/ref:50.276 - timing_s/adv:0.059 - timing_s/update_actor:190.039 - timing_s/step:426.396 - timing_per_token_ms/gen:0.087 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3380790.000 - perf/time_per_step:426.396 - perf/throughput:991.094
(TaskRunner pid=362057) 
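The `WARN: rank N grad_norm is not finite: nan` lines in the log correspond to a finiteness guard on the clipped gradient norm. A minimal stand-alone sketch of such a check (a hypothetical helper, not verl's actual code), which lets the caller skip the optimizer step for the affected batch:

```python
import math

def grad_norm_is_finite(grad_norm: float, rank: int = 0) -> bool:
    """Return False (and warn) when the gradient norm is NaN/Inf,
    so the training loop can skip optimizer.step() for this batch."""
    if not math.isfinite(grad_norm):
        print(f"WARN: rank {rank} grad_norm is not finite: {grad_norm}")
        return False
    return True
```

Note that in the logs above the norm is NaN on every step from step 2 onward, which suggests a persistent numerical issue upstream of the clipping rather than a single bad batch.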

The versions of my Python packages are as follows:

Package                           Version       Editable project location
--------------------------------- ------------- -------------------------
absl-py                           2.1.0
accelerate                        1.6.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.11.18
aiohttp-cors                      0.8.1
aiosignal                         1.3.2
annotated-types                   0.7.0
anthropic                         0.50.0
antlr4-python3-runtime            4.9.3
anyio                             4.9.0
asttokens                         3.0.0
astunparse                        1.6.3
async-timeout                     5.0.1
attrs                             25.3.0
beautifulsoup4                    4.13.4
blobfile                          3.0.0
boto3                             1.36.16
botocore                          1.36.17
cachetools                        5.5.2
certifi                           2025.4.26
cffi                              1.17.1
cfgv                              3.4.0
chardet                           5.2.0
charset-normalizer                3.4.1
click                             8.1.8
cloudpickle                       3.1.1
cmake                             3.31.4
codetiming                        1.4.0
colorful                          0.5.6
compressed-tensors                0.9.4
cuda-bindings                     12.8.0
cuda-python                       12.8.0
datasets                          3.5.1
decorator                         5.2.1
decord                            0.6.0
dill                              0.3.8
diskcache                         5.6.3
distlib                           0.3.9
distro                            1.9.0
docker-pycreds                    0.4.0
duckduckgo_search                 8.0.1
einops                            0.8.1
einops-exts                       0.0.4
exceptiongroup                    1.2.2
executing                         2.2.0
expecttest                        0.3.0
fastapi                           0.115.12
filelock                          3.18.0
flamingo-pytorch                  0.1.2
flash_attn                        2.7.4.post1
flashinfer-python                 0.2.3
frozenlist                        1.6.0
fsspec                            2025.3.0
ftfy                              6.3.1
gguf                              0.10.0
gitdb                             4.0.12
GitPython                         3.1.44
google-api-core                   2.24.2
google-auth                       2.39.0
googleapis-common-protos          1.70.0
grpcio                            1.71.0
h11                               0.16.0
hf_transfer                       0.1.9
httpcore                          1.0.9
httptools                         0.6.4
httpx                             0.28.1
huggingface-hub                   0.30.2
hydra-core                        1.3.2
hypothesis                        6.125.2
identify                          2.6.10
idna                              3.10
importlib_metadata                8.7.0
interegular                       0.3.3
ipython                           8.36.0
jedi                              0.19.2
Jinja2                            3.1.6
jiter                             0.9.0
jmespath                          1.0.1
jsonschema                        4.23.0
jsonschema-specifications         2025.4.1
lark                              1.2.2
liger_kernel                      0.5.8
lintrunner                        0.12.7
litellm                           1.67.5
llguidance                        0.7.19
llvmlite                          0.44.0
lm-format-enforcer                0.10.6
lxml                              5.4.0
Markdown                          3.7
markdown-it-py                    3.0.0
markdownify                       1.1.0
MarkupSafe                        3.0.2
matplotlib-inline                 0.1.7
mdurl                             0.1.2
mistral_common                    1.5.4
modelscope                        1.25.0
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.19.0
multidict                         6.4.3
multiprocess                      0.70.16
nanobind                          2.7.0
nest-asyncio                      1.6.0
networkx                          3.4.2
ninja                             1.11.1.3
nodeenv                           1.9.1
numba                             0.61.2
numpy                             1.26.4
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.2
nvidia-ml-py                      12.570.86
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
omegaconf                         2.3.0
open_clip_torch                   2.30.0
openai                            1.76.2
opencensus                        0.11.4
opencensus-context                0.1.3
opencv-contrib-python             4.11.0.86
opencv-python                     4.11.0.86
opencv-python-headless            4.11.0.86
optree                            0.14.0
orjson                            3.10.18
outlines                          0.0.46
packaging                         25.0
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post5
peft                              0.15.2
pexpect                           4.9.0
pillow                            11.1.0
pip                               25.0.1
platformdirs                      4.3.7
pre_commit                        4.2.0
primp                             0.15.0
prometheus_client                 0.21.1
prometheus-fastapi-instrumentator 7.1.0
prompt_toolkit                    3.0.51
propcache                         0.3.1
proto-plus                        1.26.1
protobuf                          6.30.2
psutil                            7.0.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
py-spy                            0.4.0
pyairports                        2.1.1
pyarrow                           20.0.0
pyasn1                            0.6.1
pyasn1_modules                    0.4.2
pybind11                          2.13.6
pycountry                         24.6.1
pycparser                         2.22
pycryptodomex                     3.22.0
pydantic                          2.11.4
pydantic_core                     2.33.2
Pygments                          2.19.1
pylatexenc                        2.10
pynvml                            12.0.0
python-dateutil                   2.9.0.post0
python-dotenv                     1.1.0
python-multipart                  0.0.20
pytz                              2025.2
PyYAML                            6.0.2
pyzmq                             26.4.0
ray                               2.45.0
referencing                       0.36.2
regex                             2024.11.6
requests                          2.32.3
rich                              14.0.0
rpds-py                           0.24.0
rsa                               4.9.1
ruamel.yaml                       0.18.10
ruamel.yaml.clib                  0.2.12
s3transfer                        0.11.2
safetensors                       0.5.3
sentencepiece                     0.2.0
sentry-sdk                        2.27.0
setproctitle                      1.3.6
setuptools                        80.1.0
sgl-kernel                        0.0.9.post2
sglang                            0.4.5.post3
six                               1.17.0
smart-open                        7.1.0
smmap                             5.0.2
smolagents                        1.14.0
sniffio                           1.3.1
sortedcontainers                  2.4.0
soundfile                         0.13.1
soupsieve                         2.7
sox                               1.5.0
stack-data                        0.6.3
starlette                         0.46.2
sympy                             1.13.1
tensorboard                       2.18.0
tensorboard-data-server           0.7.2
tensorboardX                      2.6.2.2
tensordict                        0.6.2
termcolor                         2.5.0
tiktoken                          0.8.0
timm                              1.0.14
tokenizers                        0.21.1
torch                             2.6.0+cu124
torch_memory_saver                0.0.5
torchao                           0.10.0
torchdata                         0.11.0
torchvision                       0.21.0
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.51.1
triton                            3.2.0
types-dataclasses                 0.6.6
typing_extensions                 4.13.2
typing-inspection                 0.4.0
tzdata                            2025.2
urllib3                           2.4.0
uvicorn                           0.34.2
uvloop                            0.21.0
verl                              0.2.0.dev0    /root/verl
virtualenv                        20.30.0
vllm                              0.6.3
wandb                             0.19.10
watchfiles                        1.0.5
wcwidth                           0.2.13
websockets                        15.0.1
Werkzeug                          3.1.3
wheel                             0.45.1
wrapt                             1.17.2
xformers                          0.0.27.post2
xgrammar                          0.1.17
xxhash                            3.5.0
yarl                              1.20.0
zipp                              3.21.0

My CUDA version is 12.4 (`nvcc` reports: Cuda compilation tools, release 12.4, V12.4.131).

The specific execution command is:

export VLLM_ATTENTION_BACKEND=XFORMERS
export NCCL_DEBUG=INFO
export GLOO_DEBUG=1

export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_SOCKET_IFNAME=bond1
export GLOO_SOCKET_IFNAME=bond1

export UCX_NET_DEVICES=bond1
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG="INFO"
export HYDRA_FULL_ERROR=1

ray job submit \
    -- python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=model/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.name=sglang_async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=${rollout_n} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    data.train_files=${TRAIN_FILE} \
    data.val_files=${TEST_FILE} \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=150

I sincerely hope to get some help. Thank you very much!

@SwordFaith
Collaborator Author

Hello, I used verl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh for training, and grad_norm became NaN from the second step onward.

The specific log is as follows:

Filtering prompts longer than 1024 tokens:   0%|          | 0/7473 [00:00<?, ? examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  13%|█▎        | 1000/7473 [00:00<00:03, 1835.43 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  27%|██▋       | 2000/7473 [00:01<00:03, 1603.15 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  40%|████      | 3000/7473 [00:01<00:02, 1757.81 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  54%|█████▎    | 4000/7473 [00:02<00:01, 1834.26 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  67%|██████▋   | 5000/7473 [00:02<00:01, 1883.81 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  80%|████████  | 6000/7473 [00:03<00:00, 1914.56 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  94%|█████████▎| 7000/7473 [00:03<00:00, 1933.69 examples/s]
(TaskRunner pid=362057) filter dataset len: 7473
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1942.47 examples/s]
Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1871.16 examples/s]
(TaskRunner pid=362057) dataset len: 1319
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:   0%|          | 0/1319 [00:00<?, ? examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  76%|███████▌  | 1000/1319 [00:00<00:00, 1962.69 examples/s]
(TaskRunner pid=362057) filter dataset len: 1319
(TaskRunner pid=362057) Size of train dataloader: 29, Size of val dataloader: 1
(TaskRunner pid=362057) Total training steps: 4350
(TaskRunner pid=362057) colocated worker base class <class 'verl.single_controller.base.worker.Worker'>
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.57 examples/s]
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.28 examples/s]
(TaskRunner pid=362057) DeprecationWarning: `ray.state.available_resources_per_node` is a private attribute and access will be removed in a future Ray version.
(TaskRunner pid=362057) WARNING:2025-05-01 20:50:26,029:Waiting for register center actor w58LT1_register_center to be ready. Elapsed time: 0 seconds out of 300 seconds.
(WorkerDict pid=50072) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50072) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=50072) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 85.04it/s]
(WorkerDict pid=50072) [rank3]:[W501 20:50:39.657321419 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 3]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
(WorkerDict pid=50070) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
(WorkerDict pid=50070) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 79.79it/s]
(WorkerDict pid=49813) Model config after override: Qwen2Config {
(WorkerDict pid=49813)   "architectures": [
(WorkerDict pid=49813)     "Qwen2ForCausalLM"
(WorkerDict pid=49813)   ],
(WorkerDict pid=49813)   "attention_dropout": 0.0,
(WorkerDict pid=49813)   "eos_token_id": 151645,
(WorkerDict pid=49813)   "hidden_act": "silu",
(WorkerDict pid=49813)   "hidden_size": 3584,
(WorkerDict pid=49813)   "initializer_range": 0.02,
(WorkerDict pid=49813)   "intermediate_size": 18944,
(WorkerDict pid=49813)   "max_position_embeddings": 32768,
(WorkerDict pid=49813)   "max_window_layers": 28,
(WorkerDict pid=49813)   "model_type": "qwen2",
(WorkerDict pid=49813)   "num_attention_heads": 28,
(WorkerDict pid=49813)   "num_hidden_layers": 28,
(WorkerDict pid=49813)   "num_key_value_heads": 4,
(WorkerDict pid=49813)   "pad_token_id": 151643,
(WorkerDict pid=49813)   "rms_norm_eps": 1e-06,
(WorkerDict pid=49813)   "rope_scaling": null,
(WorkerDict pid=49813)   "rope_theta": 1000000.0,
(WorkerDict pid=49813)   "sliding_window": 131072,
(WorkerDict pid=49813)   "tie_word_embeddings": false,
(WorkerDict pid=49813)   "torch_dtype": "bfloat16",
(WorkerDict pid=49813)   "transformers_version": "4.51.1",
(WorkerDict pid=49813)   "use_cache": true,
(WorkerDict pid=49813)   "use_sliding_window": false,
(WorkerDict pid=49813)   "vocab_size": 152064
(WorkerDict pid=49813) }
(WorkerDict pid=49813) 
(WorkerDict pid=49813) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=49813) wrap_policy: functools.partial(<function _or_policy at 0x7ef5ed872320>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7ef5ed872200>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})])
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(WorkerDict pid=49813) Actor use_remove_padding=True
(WorkerDict pid=50070) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=49813) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`. [repeated 7x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 63.93it/s] [repeated 4x across cluster]
(WorkerDict pid=49813) [rank0]:[W501 20:50:41.237036992 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id. [repeated 7x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s] [repeated 2x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 75.93it/s] [repeated 2x across cluster]
(WorkerDict pid=49813) Model config after override: Qwen2Config {
(WorkerDict pid=49813)   "architectures": [
(WorkerDict pid=49813)     "Qwen2ForCausalLM"
(WorkerDict pid=49813)   ],
(WorkerDict pid=49813)   "attention_dropout": 0.0,
(WorkerDict pid=49813)   "eos_token_id": 151645,
(WorkerDict pid=49813)   "hidden_act": "silu",
(WorkerDict pid=49813)   "hidden_size": 3584,
(WorkerDict pid=49813)   "initializer_range": 0.02,
(WorkerDict pid=49813)   "intermediate_size": 18944,
(WorkerDict pid=49813)   "max_position_embeddings": 32768,
(WorkerDict pid=49813)   "max_window_layers": 28,
(WorkerDict pid=49813)   "model_type": "qwen2",
(WorkerDict pid=49813)   "num_attention_heads": 28,
(WorkerDict pid=49813)   "num_hidden_layers": 28,
(WorkerDict pid=49813)   "num_key_value_heads": 4,
(WorkerDict pid=49813)   "pad_token_id": 151643,
(WorkerDict pid=49813)   "rms_norm_eps": 1e-06,
(WorkerDict pid=49813)   "rope_scaling": null,
(WorkerDict pid=49813)   "rope_theta": 1000000.0,
(WorkerDict pid=49813)   "sliding_window": 131072,
(WorkerDict pid=49813)   "tie_word_embeddings": false,
(WorkerDict pid=49813)   "torch_dtype": "bfloat16",
(WorkerDict pid=49813)   "transformers_version": "4.51.1",
(WorkerDict pid=49813)   "use_cache": true,
(WorkerDict pid=49813)   "use_sliding_window": false,
(WorkerDict pid=49813)   "vocab_size": 152064
(WorkerDict pid=49813) }
(WorkerDict pid=49813) 
(WorkerDict pid=49813) 
Loading checkpoint shards:  25%|██▌       | 1/4 [00:02<00:08,  2.88s/it]
(WorkerDict pid=50074) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)` [repeated 7x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s] [repeated 8x across cluster]
(WorkerDict pid=50074) 
Loading checkpoint shards:  75%|███████▌  | 3/4 [00:08<00:02,  2.88s/it] [repeated 17x across cluster]
(WorkerDict pid=50075) wrap_policy: functools.partial(<function _or_policy at 0x7f4483a6a560>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7f4483a6a440>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})]) [repeated 7x across cluster]
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50075) Actor use_remove_padding=True [repeated 7x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.59s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.62s/it]
(WorkerDict pid=50074) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=50070) Total steps: 4350, num_warmup_steps: 0
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
(WorkerDict pid=50070) No module named 'vllm._version'
(WorkerDict pid=50070)   from vllm.version import __version__ as VLLM_VERSION
(WorkerDict pid=50071) Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.17s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.25s/it] [repeated 7x across cluster]
(WorkerDict pid=50074) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50071) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 6x across cluster]
(WorkerDict pid=49813) Before building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 20.63/95.00
(WorkerDict pid=50070) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) Total steps: 4350, num_warmup_steps: 0 [repeated 7x across cluster]
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=50070)   warnings.warn(
(WorkerDict pid=49813) Capturing batches (avail_mem=35.53 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50071) Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50074) Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50076) Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=49813) After building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) After building sharding manager, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(TaskRunner pid=362057) Using LocalLogger is deprecated. The constructor API will change 
(TaskRunner pid=362057) Checkpoint tracker file does not exist: %s /tmp/ray/session_2025-05-01_20-49-28_274415_338741/runtime_resources/working_dir_files/_ray_pkg_4ce66e647b0fe766/checkpoints/test/latest_checkpointed_iteration.txt
(TaskRunner pid=362057) Training from scratch
(TaskRunner pid=362057) test_gen_batch meta info: {'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': False, 'validate': True}
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/sglang/srt/utils.py:888: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
(WorkerDict pid=49813)   tensor_data = torch.ByteTensor(
(WorkerDict pid=50076) Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.44it/s]
Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.32it/s] [repeated 6x across cluster]
(TaskRunner pid=362057) validation generation end
(TaskRunner pid=362057) [prompt] system
(TaskRunner pid=362057) 
(TaskRunner pid=362057)                             You are a math expert. You are given a question and you need to solve it step by step.  
(TaskRunner pid=362057)                             `calc_gsm8k_reward` is a tool for calculating the reward of gsm8k. You should use this 
(TaskRunner pid=362057)                             tool to calculate the reward of your answer(1.0 if your answer is correct, 0.0 if your 
(TaskRunner pid=362057)                             answer is incorrect) before submitting it and refine your answer if necessary. Put your 
(TaskRunner pid=362057)                             final answer in the format of `#### <answer>`.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) # Tools
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You may call one or more functions to assist with the user query.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You are provided with function signatures within <tools></tools> XML tags:
(TaskRunner pid=362057) <tools>
(TaskRunner pid=362057) {"type": "function", "function": {"name": "calc_gsm8k_reward", "description": "A tool for calculating the reward of gsm8k. (1.0 if your answer is correct, 0.0 if your answer is incorrect)", "parameters": {"type": "object", "properties": {"answer": {"type": "string", "description": "The model's answer to the GSM8K math problem, must be a digits", "enum": null}}, "required": ["answer"]}, "strict": false}}
(TaskRunner pid=362057) </tools>
(TaskRunner pid=362057) 
(TaskRunner pid=362057) For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": <function-name>, "arguments": <args-json-object>}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) user
(TaskRunner pid=362057) Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 
(TaskRunner pid=362057)         You must use the `calc_gsm8k_reward` tool to calculate the reward 
(TaskRunner pid=362057)         of your answer(1.0 if your answer is correct, 0.0 if your answer is incorrect) 
(TaskRunner pid=362057)         before submitting it at least once and refine your answer if necessary. 
(TaskRunner pid=362057)         Put your final answer in the format of `#### <answer>`.
(TaskRunner pid=362057)     
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) 
(TaskRunner pid=362057) [response] To find out how much Janet makes every day at the farmers' market, we need to follow these steps:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Calculate the total number of eggs laid by the ducks per day.
(TaskRunner pid=362057) 2. Subtract the number of eggs Janet eats for breakfast and the number of eggs she uses for baking muffins.
(TaskRunner pid=362057) 3. Multiply the remaining number of eggs by the price per egg to find out how much she makes at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Let's do the calculations:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Total number of eggs laid per day: 16
(TaskRunner pid=362057) 2. Number of eggs Janet eats for breakfast: 3
(TaskRunner pid=362057) 3. Number of eggs Janet uses for baking muffins: 4
(TaskRunner pid=362057) 4. Remaining number of eggs: 16 - 3 - 4 = 9
(TaskRunner pid=362057) 5. Price per egg: $2
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's calculate the total amount Janet makes at the farmers' market:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) \[ \text{Total amount} = \text{Remaining number of eggs} \times \text{Price per egg} \]
(TaskRunner pid=362057) \[ \text{Total amount} = 9 \times 2 = 18 \]
(TaskRunner pid=362057) 
(TaskRunner pid=362057) So, Janet makes $18 every day at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's use the `calc_gsm8k_reward` tool to check the correctness of our answer.
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": "calc_gsm8k_reward", "arguments": "{\"answer\": \"18\"}"}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) tool
(TaskRunner pid=362057) Current parsed answer='18' reward=1.0
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) #### 18
(TaskRunner pid=362057) [ground_truth] 18
(TaskRunner pid=362057) [score] 1.0
(TaskRunner pid=362057) 'Initial validation metrics: {}'
(TaskRunner pid=362057) step:0
(TaskRunner pid=362057) Training Progress:   0%|          | 0/4350 [00:00<?, ?it/s]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(TaskRunner pid=362057) step:1 - global_seqlen/min:379680.000 - global_seqlen/max:457144.000 - global_seqlen/minmax_diff:77464.000 - global_seqlen/balanced_min:416128.000 - global_seqlen/balanced_max:416129.000 - global_seqlen/mean:416128.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.004 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:0.010 - perf/mfu/actor:0.700 - perf/max_memory_allocated_gb:49.644 - perf/max_memory_reserved_gb:80.434 - perf/cpu_memory_used_gb:108.922 - actor/lr:0.000 - training/global_step:1.000 - training/epoch:0.000 - critic/score/mean:0.905 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.905 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.653 - critic/advantages/min:-1.436 - critic/returns/mean:-0.000 - critic/returns/max:0.653 - critic/returns/min:-1.436 - response_length/mean:364.025 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.026 - prompt_length/mean:448.727 - prompt_length/max:533.000 - prompt_length/min:410.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:118.270 - timing_s/reward:1.181 - timing_s/old_log_prob:52.755 - timing_s/ref:54.759 - timing_s/adv:0.059 - timing_s/update_actor:188.158 - timing_s/step:415.307 - timing_per_token_ms/gen:0.079 - timing_per_token_ms/update_actor:0.057 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.016 - perf/total_num_tokens:3329030.000 - perf/time_per_step:415.307 - perf/throughput:1001.979
(TaskRunner pid=362057) Training Progress:   0%|          | 1/4350 [06:57<504:26:22, 417.56s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(TaskRunner pid=362057) step:2 - global_seqlen/min:407422.000 - global_seqlen/max:440440.000 - global_seqlen/minmax_diff:33018.000 - global_seqlen/balanced_min:418815.000 - global_seqlen/balanced_max:418816.000 - global_seqlen/mean:418815.625 - actor/entropy_loss:0.218 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.006 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.703 - perf/max_memory_allocated_gb:54.050 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.974 - actor/lr:0.000 - training/global_step:2.000 - training/epoch:0.000 - critic/score/mean:0.903 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.903 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.001 - critic/advantages/max:1.436 - critic/advantages/min:-1.677 - critic/returns/mean:-0.001 - critic/returns/max:1.436 - critic/returns/min:-1.677 - response_length/mean:368.902 - response_length/max:1024.000 - response_length/min:156.000 - response_length/clip_ratio:0.028 - prompt_length/mean:449.098 - prompt_length/max:559.000 - prompt_length/min:409.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:112.222 - timing_s/reward:1.169 - timing_s/old_log_prob:50.543 - timing_s/ref:49.970 - timing_s/adv:0.057 - timing_s/update_actor:188.153 - timing_s/step:402.215 - timing_per_token_ms/gen:0.074 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3350525.000 - perf/time_per_step:402.215 - perf/throughput:1041.273
(TaskRunner pid=362057) Training Progress:   0%|          | 2/4350 [13:40<493:36:46, 408.70s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 7x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:3 - global_seqlen/min:387116.000 - global_seqlen/max:433827.000 - global_seqlen/minmax_diff:46711.000 - global_seqlen/balanced_min:416119.000 - global_seqlen/balanced_max:416120.000 - global_seqlen/mean:416119.125 - actor/entropy_loss:0.212 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.001 - actor/pg_clipfrac:0.000 - actor/ppo_kl:-0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.709 - perf/max_memory_allocated_gb:54.534 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.965 - actor/lr:0.000 - training/global_step:3.000 - training/epoch:0.000 - critic/score/mean:0.890 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.890 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.250 - critic/advantages/min:-3.750 - critic/returns/mean:-0.000 - critic/returns/max:0.250 - critic/returns/min:-3.750 - response_length/mean:364.510 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.020 - prompt_length/mean:448.223 - prompt_length/max:534.000 - prompt_length/min:411.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:125.431 - timing_s/reward:1.166 - timing_s/old_log_prob:49.390 - timing_s/ref:49.245 - timing_s/adv:0.058 - timing_s/update_actor:185.307 - timing_s/step:410.658 - timing_per_token_ms/gen:0.084 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3328953.000 - perf/time_per_step:410.658 - perf/throughput:1013.299
(TaskRunner pid=362057) Training Progress:   0%|          | 3/4350 [20:30<494:43:43, 409.71s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 6x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:4 - global_seqlen/min:408386.000 - global_seqlen/max:454859.000 - global_seqlen/minmax_diff:46473.000 - global_seqlen/balanced_min:422598.000 - global_seqlen/balanced_max:422599.000 - global_seqlen/mean:422598.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.003 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.705 - perf/max_memory_allocated_gb:55.359 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:110.119 - actor/lr:0.000 - training/global_step:4.000 - training/epoch:0.000 - critic/score/mean:0.872 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.872 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.016 - critic/advantages/min:-0.465 - critic/returns/mean:-0.000 - critic/returns/max:2.016 - critic/returns/min:-0.465 - response_length/mean:377.724 - response_length/max:1024.000 - response_length/min:95.000 - response_length/clip_ratio:0.029 - prompt_length/mean:447.664 - prompt_length/max:550.000 - prompt_length/min:415.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:133.983 - timing_s/reward:1.183 - timing_s/old_log_prob:50.797 - timing_s/ref:50.276 - timing_s/adv:0.059 - timing_s/update_actor:190.039 - timing_s/step:426.396 - timing_per_token_ms/gen:0.087 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3380790.000 - perf/time_per_step:426.396 - perf/throughput:991.094

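The repeated `grad_norm is not finite: nan` warnings above are the core symptom: starting from step 2, every rank reports a non-finite gradient norm and the actor update is effectively skipped. As a minimal sketch of the pattern behind such a warning (a plain PyTorch loop for illustration, not verl's actual update code; `check_grad_norm` is a hypothetical helper), one can clip gradients, test the returned norm for finiteness, and skip `optimizer.step()` for that batch:

```python
import torch

def check_grad_norm(model, max_norm=1.0):
    # clip_grad_norm_ returns the total gradient norm computed
    # *before* clipping; a nan/inf here means the backward pass
    # produced a bad gradient and the optimizer step should be skipped.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    is_finite = bool(torch.isfinite(grad_norm))
    if not is_finite:
        print(f"WARN: grad_norm is not finite: {grad_norm}")
    return is_finite

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
model.weight.grad[0, 0] = float("nan")  # inject a bad gradient for demonstration
skip_step = not check_grad_norm(model)
print(skip_step)  # True: this batch's update would be skipped
```

Since the warning fires on every rank from step 2 onward while step 1 is finite, the nan is most likely introduced by the first optimizer step or the loss-mask handling rather than by a single bad batch.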
The versions of my Python packages are as follows:

Package                           Version       Editable project location
--------------------------------- ------------- -------------------------
absl-py                           2.1.0
accelerate                        1.6.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.11.18
aiohttp-cors                      0.8.1
aiosignal                         1.3.2
annotated-types                   0.7.0
anthropic                         0.50.0
antlr4-python3-runtime            4.9.3
anyio                             4.9.0
asttokens                         3.0.0
astunparse                        1.6.3
async-timeout                     5.0.1
attrs                             25.3.0
beautifulsoup4                    4.13.4
blobfile                          3.0.0
boto3                             1.36.16
botocore                          1.36.17
cachetools                        5.5.2
certifi                           2025.4.26
cffi                              1.17.1
cfgv                              3.4.0
chardet                           5.2.0
charset-normalizer                3.4.1
click                             8.1.8
cloudpickle                       3.1.1
cmake                             3.31.4
codetiming                        1.4.0
colorful                          0.5.6
compressed-tensors                0.9.4
cuda-bindings                     12.8.0
cuda-python                       12.8.0
datasets                          3.5.1
decorator                         5.2.1
decord                            0.6.0
dill                              0.3.8
diskcache                         5.6.3
distlib                           0.3.9
distro                            1.9.0
docker-pycreds                    0.4.0
duckduckgo_search                 8.0.1
einops                            0.8.1
einops-exts                       0.0.4
exceptiongroup                    1.2.2
executing                         2.2.0
expecttest                        0.3.0
fastapi                           0.115.12
filelock                          3.18.0
flamingo-pytorch                  0.1.2
flash_attn                        2.7.4.post1
flashinfer-python                 0.2.3
frozenlist                        1.6.0
fsspec                            2025.3.0
ftfy                              6.3.1
gguf                              0.10.0
gitdb                             4.0.12
GitPython                         3.1.44
google-api-core                   2.24.2
google-auth                       2.39.0
googleapis-common-protos          1.70.0
grpcio                            1.71.0
h11                               0.16.0
hf_transfer                       0.1.9
httpcore                          1.0.9
httptools                         0.6.4
httpx                             0.28.1
huggingface-hub                   0.30.2
hydra-core                        1.3.2
hypothesis                        6.125.2
identify                          2.6.10
idna                              3.10
importlib_metadata                8.7.0
interegular                       0.3.3
ipython                           8.36.0
jedi                              0.19.2
Jinja2                            3.1.6
jiter                             0.9.0
jmespath                          1.0.1
jsonschema                        4.23.0
jsonschema-specifications         2025.4.1
lark                              1.2.2
liger_kernel                      0.5.8
lintrunner                        0.12.7
litellm                           1.67.5
llguidance                        0.7.19
llvmlite                          0.44.0
lm-format-enforcer                0.10.6
lxml                              5.4.0
Markdown                          3.7
markdown-it-py                    3.0.0
markdownify                       1.1.0
MarkupSafe                        3.0.2
matplotlib-inline                 0.1.7
mdurl                             0.1.2
mistral_common                    1.5.4
modelscope                        1.25.0
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.19.0
multidict                         6.4.3
multiprocess                      0.70.16
nanobind                          2.7.0
nest-asyncio                      1.6.0
networkx                          3.4.2
ninja                             1.11.1.3
nodeenv                           1.9.1
numba                             0.61.2
numpy                             1.26.4
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.2
nvidia-ml-py                      12.570.86
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
omegaconf                         2.3.0
open_clip_torch                   2.30.0
openai                            1.76.2
opencensus                        0.11.4
opencensus-context                0.1.3
opencv-contrib-python             4.11.0.86
opencv-python                     4.11.0.86
opencv-python-headless            4.11.0.86
optree                            0.14.0
orjson                            3.10.18
outlines                          0.0.46
packaging                         25.0
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post5
peft                              0.15.2
pexpect                           4.9.0
pillow                            11.1.0
pip                               25.0.1
platformdirs                      4.3.7
pre_commit                        4.2.0
primp                             0.15.0
prometheus_client                 0.21.1
prometheus-fastapi-instrumentator 7.1.0
prompt_toolkit                    3.0.51
propcache                         0.3.1
proto-plus                        1.26.1
protobuf                          6.30.2
psutil                            7.0.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
py-spy                            0.4.0
pyairports                        2.1.1
pyarrow                           20.0.0
pyasn1                            0.6.1
pyasn1_modules                    0.4.2
pybind11                          2.13.6
pycountry                         24.6.1
pycparser                         2.22
pycryptodomex                     3.22.0
pydantic                          2.11.4
pydantic_core                     2.33.2
Pygments                          2.19.1
pylatexenc                        2.10
pynvml                            12.0.0
python-dateutil                   2.9.0.post0
python-dotenv                     1.1.0
python-multipart                  0.0.20
pytz                              2025.2
PyYAML                            6.0.2
pyzmq                             26.4.0
ray                               2.45.0
referencing                       0.36.2
regex                             2024.11.6
requests                          2.32.3
rich                              14.0.0
rpds-py                           0.24.0
rsa                               4.9.1
ruamel.yaml                       0.18.10
ruamel.yaml.clib                  0.2.12
s3transfer                        0.11.2
safetensors                       0.5.3
sentencepiece                     0.2.0
sentry-sdk                        2.27.0
setproctitle                      1.3.6
setuptools                        80.1.0
sgl-kernel                        0.0.9.post2
sglang                            0.4.5.post3
six                               1.17.0
smart-open                        7.1.0
smmap                             5.0.2
smolagents                        1.14.0
sniffio                           1.3.1
sortedcontainers                  2.4.0
soundfile                         0.13.1
soupsieve                         2.7
sox                               1.5.0
stack-data                        0.6.3
starlette                         0.46.2
sympy                             1.13.1
tensorboard                       2.18.0
tensorboard-data-server           0.7.2
tensorboardX                      2.6.2.2
tensordict                        0.6.2
termcolor                         2.5.0
tiktoken                          0.8.0
timm                              1.0.14
tokenizers                        0.21.1
torch                             2.6.0+cu124
torch_memory_saver                0.0.5
torchao                           0.10.0
torchdata                         0.11.0
torchvision                       0.21.0
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.51.1
triton                            3.2.0
types-dataclasses                 0.6.6
typing_extensions                 4.13.2
typing-inspection                 0.4.0
tzdata                            2025.2
urllib3                           2.4.0
uvicorn                           0.34.2
uvloop                            0.21.0
verl                              0.2.0.dev0    /root/verl
virtualenv                        20.30.0
vllm                              0.6.3
wandb                             0.19.10
watchfiles                        1.0.5
wcwidth                           0.2.13
websockets                        15.0.1
Werkzeug                          3.1.3
wheel                             0.45.1
wrapt                             1.17.2
xformers                          0.0.27.post2
xgrammar                          0.1.17
xxhash                            3.5.0
yarl                              1.20.0
zipp                              3.21.0

My CUDA version is 12.4 (nvcc reports: Cuda compilation tools, release 12.4, V12.4.131).

The specific execution command is:

export VLLM_ATTENTION_BACKEND=XFORMERS
export NCCL_DEBUG=INFO
export GLOO_DEBUG=1

export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_SOCKET_IFNAME=bond1
export GLOO_SOCKET_IFNAME=bond1

export UCX_NET_DEVICES=bond1
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG="INFO"
export HYDRA_FULL_ERROR=1

ray job submit \
    -- python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=model/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.name=sglang_async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=${rollout_n} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    data.train_files=${TRAIN_FILE} \
    data.val_files=${TEST_FILE} \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=150
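As a side note on the config above: my understanding is that `kl_loss_type=low_var_kl` selects the low-variance "k3" KL estimator, `exp(d) - d - 1` with `d = ref_logprob - logprob`. A minimal pure-Python sketch for sanity-checking the KL values reported in the logs (the function name is mine, not verl's API):

```python
import math

def low_var_kl(logprob: float, ref_logprob: float) -> float:
    """Low-variance (k3) estimator of KL(pi || pi_ref) for one token.

    With d = ref_logprob - logprob, the estimator exp(d) - d - 1 is
    always non-negative, unlike the naive -d estimator.
    """
    d = ref_logprob - logprob
    return math.exp(d) - d - 1.0

# Identical policies give zero KL; diverging log-probs give a positive value.
print(low_var_kl(-2.0, -2.0))       # 0.0
print(low_var_kl(-1.0, -2.0) > 0)   # True
```

If the per-token `d` values ever blow up (e.g. after a bad update), the `exp(d)` term can overflow, which is one plausible route from instability to a NaN grad_norm.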

I sincerely hope to get some help. Thank you very much!

Thanks for sharing your settings for the early-NaN run. It looks like there is some instability during training. We'll add more monitoring in wandb and examine it closely.
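In the meantime, a quick way to localize where the NaN enters is to check gradients for non-finite values right before each optimizer step. A minimal pure-Python sketch of that kind of check (the helper names are illustrative, not verl APIs; in practice you would run this over the FSDP parameters' `.grad` tensors):

```python
import math

def global_grad_norm(grads):
    """Global L2 norm over flattened per-parameter gradients."""
    return math.sqrt(sum(v * v for g in grads for v in g))

def first_nonfinite(grads):
    """Return (param_index, value) of the first NaN/Inf gradient, or None."""
    for i, g in enumerate(grads):
        for v in g:
            if not math.isfinite(v):
                return i, v
    return None

grads = [[3.0], [4.0]]
print(global_grad_norm(grads))                    # 5.0
print(first_nonfinite([[0.1], [float("nan")]]))   # (1, nan)
```

Logging which parameter group first goes non-finite (and at which step) usually narrows the search to the loss term or layer responsible.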

@eric-haibin-lin eric-haibin-lin mentioned this pull request May 2, 2025
ScottCTD pushed a commit to ScottCTD/verl that referenced this pull request May 5, 2025
…olcengine#1037)

**Future Plan**
[Discussion Thread](zhaochenyang20/Awesome-ML-SYS-Tutorial#74 (comment))
[RFC doc](https://github.com/SwordFaith/verl-sglang-dev-log/blob/main/rlhf/verl/multi-turn/veRL-multiturn-rollout-RFC.md) will be updated soon.

## Contributors & Acknowledgement

- Xiang Long [mid.of.change@gmail.com](mailto:mid.of.change@gmail.com) @SwordFaith (Design RFC & core-dev of refactor part)
- Yuzhen Zhou [zyzshishui@gmail.com](mailto:zyzshishui@gmail.com) @zyzshishui (Core-dev)
- Chenyang Zhao [zhaochen20@outlook.com](mailto:zhaochen20@outlook.com) @zhaochenyang20 (PM)
- Guanhua Wang @WANG-GH
- Junrong Lin @ocss884 (verl-sglang support)
- Hanchen Zhang [zhanghanchen77@gmail.com](mailto:zhanghanchen77@gmail.com)
- Haoran Wang [ubecwang@gmail.com](mailto:ubecwang@gmail.com)
- Rui Lu [learningrate1@gmail.com](mailto:learningrate1@gmail.com)
- Yujiang Li [liyujiang2020@gmail.com](mailto:liyujiang2020@gmail.com)
- Jiajun Li [guapisolo@gmail.com](mailto:guapisolo@gmail.com)
- Jin Pan [jpan236@wisc.edu](mailto:jpan236@wisc.edu)
- Zhi Zheng [zhengzhi@modelbest.cn](mailto:zhengzhi@modelbest.cn) @zh-zheng

---------

Co-authored-by: zyzshishui <492129152@qq.com>
Co-authored-by: guanhua <281484683@qq.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
Co-authored-by: ocss884 <ocss.lin@gmail.com>
Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
Co-authored-by: HL <linhaibin.eric@gmail.com>
@jiani-huang

Hello, I used `verl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh` for training, and the grad_norm became `nan` from the second step onward.
The specific log is as follows:

Filtering prompts longer than 1024 tokens:   0%|          | 0/7473 [00:00<?, ? examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  13%|█▎        | 1000/7473 [00:00<00:03, 1835.43 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  27%|██▋       | 2000/7473 [00:01<00:03, 1603.15 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  40%|████      | 3000/7473 [00:01<00:02, 1757.81 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  54%|█████▎    | 4000/7473 [00:02<00:01, 1834.26 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  67%|██████▋   | 5000/7473 [00:02<00:01, 1883.81 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  80%|████████  | 6000/7473 [00:03<00:00, 1914.56 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  94%|█████████▎| 7000/7473 [00:03<00:00, 1933.69 examples/s]
(TaskRunner pid=362057) filter dataset len: 7473
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1942.47 examples/s]
Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1871.16 examples/s]
(TaskRunner pid=362057) dataset len: 1319
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:   0%|          | 0/1319 [00:00<?, ? examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  76%|███████▌  | 1000/1319 [00:00<00:00, 1962.69 examples/s]
(TaskRunner pid=362057) filter dataset len: 1319
(TaskRunner pid=362057) Size of train dataloader: 29, Size of val dataloader: 1
(TaskRunner pid=362057) Total training steps: 4350
(TaskRunner pid=362057) colocated worker base class <class 'verl.single_controller.base.worker.Worker'>
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.57 examples/s]
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.28 examples/s]
(TaskRunner pid=362057) DeprecationWarning: `ray.state.available_resources_per_node` is a private attribute and access will be removed in a future Ray version.
(TaskRunner pid=362057) WARNING:2025-05-01 20:50:26,029:Waiting for register center actor w58LT1_register_center to be ready. Elapsed time: 0 seconds out of 300 seconds.
(WorkerDict pid=50072) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50072) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=50072) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 85.04it/s]
(WorkerDict pid=50072) [rank3]:[W501 20:50:39.657321419 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 3]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
(WorkerDict pid=50070) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
(WorkerDict pid=50070) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 79.79it/s]
(WorkerDict pid=49813) Model config after override: Qwen2Config {
(WorkerDict pid=49813)   "architectures": [
(WorkerDict pid=49813)     "Qwen2ForCausalLM"
(WorkerDict pid=49813)   ],
(WorkerDict pid=49813)   "attention_dropout": 0.0,
(WorkerDict pid=49813)   "eos_token_id": 151645,
(WorkerDict pid=49813)   "hidden_act": "silu",
(WorkerDict pid=49813)   "hidden_size": 3584,
(WorkerDict pid=49813)   "initializer_range": 0.02,
(WorkerDict pid=49813)   "intermediate_size": 18944,
(WorkerDict pid=49813)   "max_position_embeddings": 32768,
(WorkerDict pid=49813)   "max_window_layers": 28,
(WorkerDict pid=49813)   "model_type": "qwen2",
(WorkerDict pid=49813)   "num_attention_heads": 28,
(WorkerDict pid=49813)   "num_hidden_layers": 28,
(WorkerDict pid=49813)   "num_key_value_heads": 4,
(WorkerDict pid=49813)   "pad_token_id": 151643,
(WorkerDict pid=49813)   "rms_norm_eps": 1e-06,
(WorkerDict pid=49813)   "rope_scaling": null,
(WorkerDict pid=49813)   "rope_theta": 1000000.0,
(WorkerDict pid=49813)   "sliding_window": 131072,
(WorkerDict pid=49813)   "tie_word_embeddings": false,
(WorkerDict pid=49813)   "torch_dtype": "bfloat16",
(WorkerDict pid=49813)   "transformers_version": "4.51.1",
(WorkerDict pid=49813)   "use_cache": true,
(WorkerDict pid=49813)   "use_sliding_window": false,
(WorkerDict pid=49813)   "vocab_size": 152064
(WorkerDict pid=49813) }
(WorkerDict pid=49813) 
(WorkerDict pid=49813) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=49813) wrap_policy: functools.partial(<function _or_policy at 0x7ef5ed872320>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7ef5ed872200>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})])
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(WorkerDict pid=49813) Actor use_remove_padding=True
(WorkerDict pid=50070) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=49813) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`. [repeated 7x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 63.93it/s] [repeated 4x across cluster]
(WorkerDict pid=49813) [rank0]:[W501 20:50:41.237036992 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id. [repeated 7x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s] [repeated 2x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 75.93it/s] [repeated 2x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards:  25%|██▌       | 1/4 [00:02<00:08,  2.88s/it]
(WorkerDict pid=50074) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)` [repeated 7x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s] [repeated 8x across cluster]
(WorkerDict pid=50074) 
Loading checkpoint shards:  75%|███████▌  | 3/4 [00:08<00:02,  2.88s/it] [repeated 17x across cluster]
(WorkerDict pid=50075) wrap_policy: functools.partial(<function _or_policy at 0x7f4483a6a560>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7f4483a6a440>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})]) [repeated 7x across cluster]
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50075) Actor use_remove_padding=True [repeated 7x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.59s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.62s/it]
(WorkerDict pid=50074) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=50070) Total steps: 4350, num_warmup_steps: 0
(WorkerDict pid=50075) wrap_policy: functools.partial(<function _or_policy at 0x7f4483a6a560>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7f4483a6a440>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})]) [repeated 8x across cluster]
(WorkerDict pid=49813) Actor use_remove_padding=True [repeated 8x across cluster]
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
(WorkerDict pid=50070) No module named 'vllm._version'
(WorkerDict pid=50070)   from vllm.version import __version__ as VLLM_VERSION
(WorkerDict pid=50071) 
Loading checkpoint shards:  75%|███████▌  | 3/4 [00:10<00:03,  3.34s/it] [repeated 6x across cluster]
(WorkerDict pid=50071) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.17s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.25s/it] [repeated 7x across cluster]
(WorkerDict pid=50074) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50071) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 6x across cluster]
(WorkerDict pid=49813) Before building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 20.63/95.00
(WorkerDict pid=50070) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) Total steps: 4350, num_warmup_steps: 0 [repeated 7x across cluster]
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=50070)   warnings.warn(
(WorkerDict pid=50071) NCCL version 2.21.5+cuda12.4 [repeated 2x across cluster]
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 7x across cluster]
(WorkerDict pid=49813) No module named 'vllm._version' [repeated 7x across cluster]
(WorkerDict pid=49813)   from vllm.version import __version__ as VLLM_VERSION [repeated 7x across cluster]
(WorkerDict pid=50075) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html . [repeated 3x across cluster]
(WorkerDict pid=50075)   warnings.warn( [repeated 3x across cluster]
(WorkerDict pid=50071) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 4x across cluster]
(WorkerDict pid=50071) No module named 'vllm._version' [repeated 4x across cluster]
(WorkerDict pid=50071)   from vllm.version import __version__ as VLLM_VERSION [repeated 4x across cluster]
(WorkerDict pid=50075) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False} [repeated 3x across cluster]
(WorkerDict pid=49813) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50071) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50076) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 3x across cluster]
(WorkerDict pid=50076) No module named 'vllm._version' [repeated 3x across cluster]
(WorkerDict pid=50076)   from vllm.version import __version__ as VLLM_VERSION [repeated 3x across cluster]
(WorkerDict pid=49813) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.53 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=35.53 GB):   3%|▎         | 1/35 [00:00<00:23,  1.45it/s]
Capturing batches (avail_mem=35.08 GB):   3%|▎         | 1/35 [00:00<00:23,  1.45it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=35.08 GB):   6%|▌         | 2/35 [00:01<00:19,  1.73it/s]
Capturing batches (avail_mem=34.91 GB):   6%|▌         | 2/35 [00:01<00:19,  1.73it/s]
(WorkerDict pid=50074) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50076) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=49813) 
Capturing batches (avail_mem=34.91 GB):   9%|▊         | 3/35 [00:01<00:19,  1.62it/s]
Capturing batches (avail_mem=34.74 GB):   9%|▊         | 3/35 [00:01<00:19,  1.62it/s]
(WorkerDict pid=50076) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 5x across cluster]
(WorkerDict pid=50076) No module named 'vllm._version' [repeated 5x across cluster]
(WorkerDict pid=50076)   from vllm.version import __version__ as VLLM_VERSION [repeated 5x across cluster]
(WorkerDict pid=50071) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50071) 
Capturing batches (avail_mem=34.69 GB):  20%|██        | 7/35 [00:04<00:16,  1.71it/s]
Capturing batches (avail_mem=34.54 GB):  20%|██        | 7/35 [00:04<00:16,  1.71it/s] [repeated 13x across cluster]
(WorkerDict pid=50074) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50076) 
Capturing batches (avail_mem=34.54 GB):  23%|██▎       | 8/35 [00:03<00:11,  2.26it/s]
Capturing batches (avail_mem=34.39 GB):  23%|██▎       | 8/35 [00:03<00:11,  2.26it/s] [repeated 40x across cluster]
(WorkerDict pid=50076) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=32.36 GB):  91%|█████████▏| 32/35 [00:15<00:01,  2.34it/s]
Capturing batches (avail_mem=32.35 GB):  91%|█████████▏| 32/35 [00:15<00:01,  2.34it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=32.35 GB):  94%|█████████▍| 33/35 [00:15<00:00,  2.33it/s]
Capturing batches (avail_mem=32.35 GB):  94%|█████████▍| 33/35 [00:15<00:00,  2.33it/s]
(WorkerDict pid=50076) 
Capturing batches (avail_mem=33.24 GB):  57%|█████▋    | 20/35 [00:08<00:06,  2.38it/s]
Capturing batches (avail_mem=33.16 GB):  57%|█████▋    | 20/35 [00:08<00:06,  2.38it/s] [repeated 47x across cluster]
(WorkerDict pid=49813) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) After building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) After building sharding manager, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=49813)   warnings.warn(
(WorkerDict pid=50074) 
Capturing batches (avail_mem=32.78 GB):  94%|█████████▍| 33/35 [00:13<00:00,  2.49it/s]
Capturing batches (avail_mem=32.78 GB):  94%|█████████▍| 33/35 [00:13<00:00,  2.49it/s] [repeated 8x across cluster]
(WorkerDict pid=50076) 
Capturing batches (avail_mem=32.80 GB):  89%|████████▊ | 31/35 [00:13<00:01,  2.49it/s]
Capturing batches (avail_mem=32.79 GB):  89%|████████▊ | 31/35 [00:13<00:01,  2.49it/s] [repeated 21x across cluster]
(WorkerDict pid=50076) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False} [repeated 3x across cluster]
(WorkerDict pid=50076) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html . [repeated 3x across cluster]
(WorkerDict pid=50076)   warnings.warn( [repeated 3x across cluster]
(TaskRunner pid=362057) Using LocalLogger is deprecated. The constructor API will change 
(TaskRunner pid=362057) Checkpoint tracker file does not exist: %s /tmp/ray/session_2025-05-01_20-49-28_274415_338741/runtime_resources/working_dir_files/_ray_pkg_4ce66e647b0fe766/checkpoints/test/latest_checkpointed_iteration.txt
(TaskRunner pid=362057) Training from scratch
(TaskRunner pid=362057) test_gen_batch meta info: {'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': False, 'validate': True}
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/sglang/srt/utils.py:888: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
(WorkerDict pid=49813)   tensor_data = torch.ByteTensor(
(WorkerDict pid=50076) 
Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.44it/s]
Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.32it/s] [repeated 6x across cluster]
(WorkerDict pid=50071) /usr/local/python/lib/python3.10/site-packages/sglang/srt/utils.py:888: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.) [repeated 3x across cluster]
(WorkerDict pid=50071)   tensor_data = torch.ByteTensor( [repeated 3x across cluster]
(TaskRunner pid=362057) validation generation end
(TaskRunner pid=362057) [prompt] system
(TaskRunner pid=362057) 
(TaskRunner pid=362057)                             You are a math expert. You are given a question and you need to solve it step by step.  
(TaskRunner pid=362057)                             `calc_gsm8k_reward` is a tool for calculating the reward of gsm8k. You should use this 
(TaskRunner pid=362057)                             tool to calculate the reward of your answer(1.0 if your answer is correct, 0.0 if your 
(TaskRunner pid=362057)                             answer is incorrect) before submitting it and refine your answer if necessary. Put your 
(TaskRunner pid=362057)                             final answer in the format of `#### <answer>`.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) # Tools
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You may call one or more functions to assist with the user query.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You are provided with function signatures within <tools></tools> XML tags:
(TaskRunner pid=362057) <tools>
(TaskRunner pid=362057) {"type": "function", "function": {"name": "calc_gsm8k_reward", "description": "A tool for calculating the reward of gsm8k. (1.0 if your answer is correct, 0.0 if your answer is incorrect)", "parameters": {"type": "object", "properties": {"answer": {"type": "string", "description": "The model's answer to the GSM8K math problem, must be a digits", "enum": null}}, "required": ["answer"]}, "strict": false}}
(TaskRunner pid=362057) </tools>
(TaskRunner pid=362057) 
(TaskRunner pid=362057) For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": <function-name>, "arguments": <args-json-object>}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) user
(TaskRunner pid=362057) Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 
(TaskRunner pid=362057)         You must use the `calc_gsm8k_reward` tool to calculate the reward 
(TaskRunner pid=362057)         of your answer(1.0 if your answer is correct, 0.0 if your answer is incorrect) 
(TaskRunner pid=362057)         before submitting it at least once and refine your answer if necessary. 
(TaskRunner pid=362057)         Put your final answer in the format of `#### <answer>`.
(TaskRunner pid=362057)     
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) 
(TaskRunner pid=362057) [response] To find out how much Janet makes every day at the farmers' market, we need to follow these steps:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Calculate the total number of eggs laid by the ducks per day.
(TaskRunner pid=362057) 2. Subtract the number of eggs Janet eats for breakfast and the number of eggs she uses for baking muffins.
(TaskRunner pid=362057) 3. Multiply the remaining number of eggs by the price per egg to find out how much she makes at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Let's do the calculations:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Total number of eggs laid per day: 16
(TaskRunner pid=362057) 2. Number of eggs Janet eats for breakfast: 3
(TaskRunner pid=362057) 3. Number of eggs Janet uses for baking muffins: 4
(TaskRunner pid=362057) 4. Remaining number of eggs: 16 - 3 - 4 = 9
(TaskRunner pid=362057) 5. Price per egg: $2
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's calculate the total amount Janet makes at the farmers' market:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) \[ \text{Total amount} = \text{Remaining number of eggs} \times \text{Price per egg} \]
(TaskRunner pid=362057) \[ \text{Total amount} = 9 \times 2 = 18 \]
(TaskRunner pid=362057) 
(TaskRunner pid=362057) So, Janet makes $18 every day at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's use the `calc_gsm8k_reward` tool to check the correctness of our answer.
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": "calc_gsm8k_reward", "arguments": "{\"answer\": \"18\"}"}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) tool
(TaskRunner pid=362057) Current parsed answer='18' reward=1.0
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) #### 18
(TaskRunner pid=362057) [ground_truth] 18
(TaskRunner pid=362057) [score] 1.0
(TaskRunner pid=362057) 'Initial validation metrics: {}'
(TaskRunner pid=362057) step:0
(TaskRunner pid=362057) 
Training Progress:   0%|          | 0/4350 [00:00<?, ?it/s]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(TaskRunner pid=362057) step:1 - global_seqlen/min:379680.000 - global_seqlen/max:457144.000 - global_seqlen/minmax_diff:77464.000 - global_seqlen/balanced_min:416128.000 - global_seqlen/balanced_max:416129.000 - global_seqlen/mean:416128.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.004 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:0.010 - perf/mfu/actor:0.700 - perf/max_memory_allocated_gb:49.644 - perf/max_memory_reserved_gb:80.434 - perf/cpu_memory_used_gb:108.922 - actor/lr:0.000 - training/global_step:1.000 - training/epoch:0.000 - critic/score/mean:0.905 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.905 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.653 - critic/advantages/min:-1.436 - critic/returns/mean:-0.000 - critic/returns/max:0.653 - critic/returns/min:-1.436 - response_length/mean:364.025 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.026 - prompt_length/mean:448.727 - prompt_length/max:533.000 - prompt_length/min:410.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:118.270 - timing_s/reward:1.181 - timing_s/old_log_prob:52.755 - timing_s/ref:54.759 - timing_s/adv:0.059 - timing_s/update_actor:188.158 - timing_s/step:415.307 - timing_per_token_ms/gen:0.079 - timing_per_token_ms/update_actor:0.057 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.016 - perf/total_num_tokens:3329030.000 - perf/time_per_step:415.307 - perf/throughput:1001.979
(TaskRunner pid=362057) 
Training Progress:   0%|          | 1/4350 [06:57<504:26:22, 417.56s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(TaskRunner pid=362057) step:2 - global_seqlen/min:407422.000 - global_seqlen/max:440440.000 - global_seqlen/minmax_diff:33018.000 - global_seqlen/balanced_min:418815.000 - global_seqlen/balanced_max:418816.000 - global_seqlen/mean:418815.625 - actor/entropy_loss:0.218 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.006 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.703 - perf/max_memory_allocated_gb:54.050 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.974 - actor/lr:0.000 - training/global_step:2.000 - training/epoch:0.000 - critic/score/mean:0.903 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.903 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.001 - critic/advantages/max:1.436 - critic/advantages/min:-1.677 - critic/returns/mean:-0.001 - critic/returns/max:1.436 - critic/returns/min:-1.677 - response_length/mean:368.902 - response_length/max:1024.000 - response_length/min:156.000 - response_length/clip_ratio:0.028 - prompt_length/mean:449.098 - prompt_length/max:559.000 - prompt_length/min:409.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:112.222 - timing_s/reward:1.169 - timing_s/old_log_prob:50.543 - timing_s/ref:49.970 - timing_s/adv:0.057 - timing_s/update_actor:188.153 - timing_s/step:402.215 - timing_per_token_ms/gen:0.074 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3350525.000 - perf/time_per_step:402.215 - perf/throughput:1041.273
(TaskRunner pid=362057) 
Training Progress:   0%|          | 2/4350 [13:40<493:36:46, 408.70s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 7x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:3 - global_seqlen/min:387116.000 - global_seqlen/max:433827.000 - global_seqlen/minmax_diff:46711.000 - global_seqlen/balanced_min:416119.000 - global_seqlen/balanced_max:416120.000 - global_seqlen/mean:416119.125 - actor/entropy_loss:0.212 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.001 - actor/pg_clipfrac:0.000 - actor/ppo_kl:-0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.709 - perf/max_memory_allocated_gb:54.534 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.965 - actor/lr:0.000 - training/global_step:3.000 - training/epoch:0.000 - critic/score/mean:0.890 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.890 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.250 - critic/advantages/min:-3.750 - critic/returns/mean:-0.000 - critic/returns/max:0.250 - critic/returns/min:-3.750 - response_length/mean:364.510 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.020 - prompt_length/mean:448.223 - prompt_length/max:534.000 - prompt_length/min:411.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:125.431 - timing_s/reward:1.166 - timing_s/old_log_prob:49.390 - timing_s/ref:49.245 - timing_s/adv:0.058 - timing_s/update_actor:185.307 - timing_s/step:410.658 - timing_per_token_ms/gen:0.084 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3328953.000 - perf/time_per_step:410.658 - perf/throughput:1013.299
(TaskRunner pid=362057) 
Training Progress:   0%|          | 3/4350 [20:30<494:43:43, 409.71s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 6x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:4 - global_seqlen/min:408386.000 - global_seqlen/max:454859.000 - global_seqlen/minmax_diff:46473.000 - global_seqlen/balanced_min:422598.000 - global_seqlen/balanced_max:422599.000 - global_seqlen/mean:422598.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.003 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.705 - perf/max_memory_allocated_gb:55.359 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:110.119 - actor/lr:0.000 - training/global_step:4.000 - training/epoch:0.000 - critic/score/mean:0.872 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.872 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.016 - critic/advantages/min:-0.465 - critic/returns/mean:-0.000 - critic/returns/max:2.016 - critic/returns/min:-0.465 - response_length/mean:377.724 - response_length/max:1024.000 - response_length/min:95.000 - response_length/clip_ratio:0.029 - prompt_length/mean:447.664 - prompt_length/max:550.000 - prompt_length/min:415.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:133.983 - timing_s/reward:1.183 - timing_s/old_log_prob:50.797 - timing_s/ref:50.276 - timing_s/adv:0.059 - timing_s/update_actor:190.039 - timing_s/step:426.396 - timing_per_token_ms/gen:0.087 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3380790.000 - perf/time_per_step:426.396 - perf/throughput:991.094
(TaskRunner pid=362057) 

The versions of my Python packages are as follows:

Package                           Version       Editable project location
--------------------------------- ------------- -------------------------
absl-py                           2.1.0
accelerate                        1.6.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.11.18
aiohttp-cors                      0.8.1
aiosignal                         1.3.2
annotated-types                   0.7.0
anthropic                         0.50.0
antlr4-python3-runtime            4.9.3
anyio                             4.9.0
asttokens                         3.0.0
astunparse                        1.6.3
async-timeout                     5.0.1
attrs                             25.3.0
beautifulsoup4                    4.13.4
blobfile                          3.0.0
boto3                             1.36.16
botocore                          1.36.17
cachetools                        5.5.2
certifi                           2025.4.26
cffi                              1.17.1
cfgv                              3.4.0
chardet                           5.2.0
charset-normalizer                3.4.1
click                             8.1.8
cloudpickle                       3.1.1
cmake                             3.31.4
codetiming                        1.4.0
colorful                          0.5.6
compressed-tensors                0.9.4
cuda-bindings                     12.8.0
cuda-python                       12.8.0
datasets                          3.5.1
decorator                         5.2.1
decord                            0.6.0
dill                              0.3.8
diskcache                         5.6.3
distlib                           0.3.9
distro                            1.9.0
docker-pycreds                    0.4.0
duckduckgo_search                 8.0.1
einops                            0.8.1
einops-exts                       0.0.4
exceptiongroup                    1.2.2
executing                         2.2.0
expecttest                        0.3.0
fastapi                           0.115.12
filelock                          3.18.0
flamingo-pytorch                  0.1.2
flash_attn                        2.7.4.post1
flashinfer-python                 0.2.3
frozenlist                        1.6.0
fsspec                            2025.3.0
ftfy                              6.3.1
gguf                              0.10.0
gitdb                             4.0.12
GitPython                         3.1.44
google-api-core                   2.24.2
google-auth                       2.39.0
googleapis-common-protos          1.70.0
grpcio                            1.71.0
h11                               0.16.0
hf_transfer                       0.1.9
httpcore                          1.0.9
httptools                         0.6.4
httpx                             0.28.1
huggingface-hub                   0.30.2
hydra-core                        1.3.2
hypothesis                        6.125.2
identify                          2.6.10
idna                              3.10
importlib_metadata                8.7.0
interegular                       0.3.3
ipython                           8.36.0
jedi                              0.19.2
Jinja2                            3.1.6
jiter                             0.9.0
jmespath                          1.0.1
jsonschema                        4.23.0
jsonschema-specifications         2025.4.1
lark                              1.2.2
liger_kernel                      0.5.8
lintrunner                        0.12.7
litellm                           1.67.5
llguidance                        0.7.19
llvmlite                          0.44.0
lm-format-enforcer                0.10.6
lxml                              5.4.0
Markdown                          3.7
markdown-it-py                    3.0.0
markdownify                       1.1.0
MarkupSafe                        3.0.2
matplotlib-inline                 0.1.7
mdurl                             0.1.2
mistral_common                    1.5.4
modelscope                        1.25.0
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.19.0
multidict                         6.4.3
multiprocess                      0.70.16
nanobind                          2.7.0
nest-asyncio                      1.6.0
networkx                          3.4.2
ninja                             1.11.1.3
nodeenv                           1.9.1
numba                             0.61.2
numpy                             1.26.4
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.2
nvidia-ml-py                      12.570.86
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
omegaconf                         2.3.0
open_clip_torch                   2.30.0
openai                            1.76.2
opencensus                        0.11.4
opencensus-context                0.1.3
opencv-contrib-python             4.11.0.86
opencv-python                     4.11.0.86
opencv-python-headless            4.11.0.86
optree                            0.14.0
orjson                            3.10.18
outlines                          0.0.46
packaging                         25.0
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post5
peft                              0.15.2
pexpect                           4.9.0
pillow                            11.1.0
pip                               25.0.1
platformdirs                      4.3.7
pre_commit                        4.2.0
primp                             0.15.0
prometheus_client                 0.21.1
prometheus-fastapi-instrumentator 7.1.0
prompt_toolkit                    3.0.51
propcache                         0.3.1
proto-plus                        1.26.1
protobuf                          6.30.2
psutil                            7.0.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
py-spy                            0.4.0
pyairports                        2.1.1
pyarrow                           20.0.0
pyasn1                            0.6.1
pyasn1_modules                    0.4.2
pybind11                          2.13.6
pycountry                         24.6.1
pycparser                         2.22
pycryptodomex                     3.22.0
pydantic                          2.11.4
pydantic_core                     2.33.2
Pygments                          2.19.1
pylatexenc                        2.10
pynvml                            12.0.0
python-dateutil                   2.9.0.post0
python-dotenv                     1.1.0
python-multipart                  0.0.20
pytz                              2025.2
PyYAML                            6.0.2
pyzmq                             26.4.0
ray                               2.45.0
referencing                       0.36.2
regex                             2024.11.6
requests                          2.32.3
rich                              14.0.0
rpds-py                           0.24.0
rsa                               4.9.1
ruamel.yaml                       0.18.10
ruamel.yaml.clib                  0.2.12
s3transfer                        0.11.2
safetensors                       0.5.3
sentencepiece                     0.2.0
sentry-sdk                        2.27.0
setproctitle                      1.3.6
setuptools                        80.1.0
sgl-kernel                        0.0.9.post2
sglang                            0.4.5.post3
six                               1.17.0
smart-open                        7.1.0
smmap                             5.0.2
smolagents                        1.14.0
sniffio                           1.3.1
sortedcontainers                  2.4.0
soundfile                         0.13.1
soupsieve                         2.7
sox                               1.5.0
stack-data                        0.6.3
starlette                         0.46.2
sympy                             1.13.1
tensorboard                       2.18.0
tensorboard-data-server           0.7.2
tensorboardX                      2.6.2.2
tensordict                        0.6.2
termcolor                         2.5.0
tiktoken                          0.8.0
timm                              1.0.14
tokenizers                        0.21.1
torch                             2.6.0+cu124
torch_memory_saver                0.0.5
torchao                           0.10.0
torchdata                         0.11.0
torchvision                       0.21.0
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.51.1
triton                            3.2.0
types-dataclasses                 0.6.6
typing_extensions                 4.13.2
typing-inspection                 0.4.0
tzdata                            2025.2
urllib3                           2.4.0
uvicorn                           0.34.2
uvloop                            0.21.0
verl                              0.2.0.dev0    /root/verl
virtualenv                        20.30.0
vllm                              0.6.3
wandb                             0.19.10
watchfiles                        1.0.5
wcwidth                           0.2.13
websockets                        15.0.1
Werkzeug                          3.1.3
wheel                             0.45.1
wrapt                             1.17.2
xformers                          0.0.27.post2
xgrammar                          0.1.17
xxhash                            3.5.0
yarl                              1.20.0
zipp                              3.21.0

My CUDA version is 12.4 (nvcc reports: Cuda compilation tools, release 12.4, V12.4.131).
The specific execution command is:

export VLLM_ATTENTION_BACKEND=XFORMERS
export NCCL_DEBUG=INFO
export GLOO_DEBUG=1

export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_SOCKET_IFNAME=bond1
export GLOO_SOCKET_IFNAME=bond1

export UCX_NET_DEVICES=bond1
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG="INFO"
export HYDRA_FULL_ERROR=1

ray job submit \
    -- python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=model/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.name=sglang_async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=${rollout_n} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    data.train_files=${TRAIN_FILE} \
    data.val_files=${TEST_FILE} \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=150

I sincerely hope to get some help. Thank you very much!

Thanks for sharing the settings that produce NaNs this early. There appears to be some instability during training. We'll add more monitoring in wandb and look into it closely.
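The `WARN: grad_norm is not finite` messages suggest that from step 2 onward the updates are being computed from non-finite gradients. As a stopgap while we investigate the root cause, here is a minimal sketch (plain Python; `global_grad_norm` and `should_skip_step` are hypothetical helper names, not verl APIs) of the kind of guard that skips an optimizer step when the global gradient norm goes non-finite:

```python
import math


def global_grad_norm(per_param_norms):
    """Combine per-parameter gradient L2 norms into a single global L2 norm."""
    total = 0.0
    for g in per_param_norms:
        total += g * g
    return math.sqrt(total)


def should_skip_step(per_param_norms):
    """Return True when the global gradient norm is NaN or Inf.

    math.isfinite(nan) and math.isfinite(inf) are both False, so a single
    check catches either failure mode before the optimizer step.
    """
    return not math.isfinite(global_grad_norm(per_param_norms))
```

In a real FSDP training loop the per-parameter norms would come from the sharded gradients (e.g. via `torch.nn.utils.clip_grad_norm_`, which already returns the total norm), and a skipped step should still zero the gradients so the bad batch doesn't leak into the next update.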

Same issue here. Any update on how this was solved? Thanks!

@SwordFaith
Collaborator Author

Hello, I used verl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh for training, and grad_norm became NaN from the second step onward.
The specific log is as follows:

Filtering prompts longer than 1024 tokens:   0%|          | 0/7473 [00:00<?, ? examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  13%|█▎        | 1000/7473 [00:00<00:03, 1835.43 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  27%|██▋       | 2000/7473 [00:01<00:03, 1603.15 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  40%|████      | 3000/7473 [00:01<00:02, 1757.81 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  54%|█████▎    | 4000/7473 [00:02<00:01, 1834.26 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  67%|██████▋   | 5000/7473 [00:02<00:01, 1883.81 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  80%|████████  | 6000/7473 [00:03<00:00, 1914.56 examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  94%|█████████▎| 7000/7473 [00:03<00:00, 1933.69 examples/s]
(TaskRunner pid=362057) filter dataset len: 7473
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1942.47 examples/s]
Filtering prompts longer than 1024 tokens: 100%|██████████| 7473/7473 [00:03<00:00, 1871.16 examples/s]
(TaskRunner pid=362057) dataset len: 1319
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:   0%|          | 0/1319 [00:00<?, ? examples/s]
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens:  76%|███████▌  | 1000/1319 [00:00<00:00, 1962.69 examples/s]
(TaskRunner pid=362057) filter dataset len: 1319
(TaskRunner pid=362057) Size of train dataloader: 29, Size of val dataloader: 1
(TaskRunner pid=362057) Total training steps: 4350
(TaskRunner pid=362057) colocated worker base class <class 'verl.single_controller.base.worker.Worker'>
(TaskRunner pid=362057) 
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.57 examples/s]
Filtering prompts longer than 1024 tokens: 100%|██████████| 1319/1319 [00:00<00:00, 1947.28 examples/s]
(TaskRunner pid=362057) DeprecationWarning: `ray.state.available_resources_per_node` is a private attribute and access will be removed in a future Ray version.
(TaskRunner pid=362057) WARNING:2025-05-01 20:50:26,029:Waiting for register center actor w58LT1_register_center to be ready. Elapsed time: 0 seconds out of 300 seconds.
(WorkerDict pid=50072) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50072) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=50072) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 85.04it/s]
(WorkerDict pid=50072) [rank3]:[W501 20:50:39.657321419 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 3]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
(WorkerDict pid=50070) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
(WorkerDict pid=50070) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 79.79it/s]
(WorkerDict pid=49813) Model config after override: Qwen2Config {
(WorkerDict pid=49813)   "architectures": [
(WorkerDict pid=49813)     "Qwen2ForCausalLM"
(WorkerDict pid=49813)   ],
(WorkerDict pid=49813)   "attention_dropout": 0.0,
(WorkerDict pid=49813)   "eos_token_id": 151645,
(WorkerDict pid=49813)   "hidden_act": "silu",
(WorkerDict pid=49813)   "hidden_size": 3584,
(WorkerDict pid=49813)   "initializer_range": 0.02,
(WorkerDict pid=49813)   "intermediate_size": 18944,
(WorkerDict pid=49813)   "max_position_embeddings": 32768,
(WorkerDict pid=49813)   "max_window_layers": 28,
(WorkerDict pid=49813)   "model_type": "qwen2",
(WorkerDict pid=49813)   "num_attention_heads": 28,
(WorkerDict pid=49813)   "num_hidden_layers": 28,
(WorkerDict pid=49813)   "num_key_value_heads": 4,
(WorkerDict pid=49813)   "pad_token_id": 151643,
(WorkerDict pid=49813)   "rms_norm_eps": 1e-06,
(WorkerDict pid=49813)   "rope_scaling": null,
(WorkerDict pid=49813)   "rope_theta": 1000000.0,
(WorkerDict pid=49813)   "sliding_window": 131072,
(WorkerDict pid=49813)   "tie_word_embeddings": false,
(WorkerDict pid=49813)   "torch_dtype": "bfloat16",
(WorkerDict pid=49813)   "transformers_version": "4.51.1",
(WorkerDict pid=49813)   "use_cache": true,
(WorkerDict pid=49813)   "use_sliding_window": false,
(WorkerDict pid=49813)   "vocab_size": 152064
(WorkerDict pid=49813) }
(WorkerDict pid=49813) 
(WorkerDict pid=49813) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=49813) wrap_policy: functools.partial(<function _or_policy at 0x7ef5ed872320>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7ef5ed872200>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})])
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(WorkerDict pid=49813) Actor use_remove_padding=True
(WorkerDict pid=50070) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=49813) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`. [repeated 7x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 63.93it/s] [repeated 4x across cluster]
(WorkerDict pid=49813) [rank0]:[W501 20:50:41.237036992 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id. [repeated 7x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s] [repeated 2x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 75.93it/s] [repeated 2x across cluster]
(WorkerDict pid=49813) Model config after override: Qwen2Config {
(WorkerDict pid=49813)   "architectures": [
(WorkerDict pid=49813)     "Qwen2ForCausalLM"
(WorkerDict pid=49813)   ],
(WorkerDict pid=49813)   "attention_dropout": 0.0,
(WorkerDict pid=49813)   "eos_token_id": 151645,
(WorkerDict pid=49813)   "hidden_act": "silu",
(WorkerDict pid=49813)   "hidden_size": 3584,
(WorkerDict pid=49813)   "initializer_range": 0.02,
(WorkerDict pid=49813)   "intermediate_size": 18944,
(WorkerDict pid=49813)   "max_position_embeddings": 32768,
(WorkerDict pid=49813)   "max_window_layers": 28,
(WorkerDict pid=49813)   "model_type": "qwen2",
(WorkerDict pid=49813)   "num_attention_heads": 28,
(WorkerDict pid=49813)   "num_hidden_layers": 28,
(WorkerDict pid=49813)   "num_key_value_heads": 4,
(WorkerDict pid=49813)   "pad_token_id": 151643,
(WorkerDict pid=49813)   "rms_norm_eps": 1e-06,
(WorkerDict pid=49813)   "rope_scaling": null,
(WorkerDict pid=49813)   "rope_theta": 1000000.0,
(WorkerDict pid=49813)   "sliding_window": 131072,
(WorkerDict pid=49813)   "tie_word_embeddings": false,
(WorkerDict pid=49813)   "torch_dtype": "bfloat16",
(WorkerDict pid=49813)   "transformers_version": "4.51.1",
(WorkerDict pid=49813)   "use_cache": true,
(WorkerDict pid=49813)   "use_sliding_window": false,
(WorkerDict pid=49813)   "vocab_size": 152064
(WorkerDict pid=49813) }
(WorkerDict pid=49813) 
(WorkerDict pid=49813) 
Loading checkpoint shards:  25%|██▌       | 1/4 [00:02<00:08,  2.88s/it]
(WorkerDict pid=50074) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)` [repeated 7x across cluster]
(WorkerDict pid=50075) 
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s] [repeated 8x across cluster]
(WorkerDict pid=50074) 
Loading checkpoint shards:  75%|███████▌  | 3/4 [00:08<00:02,  2.88s/it] [repeated 17x across cluster]
(WorkerDict pid=50075) wrap_policy: functools.partial(<function _or_policy at 0x7f4483a6a560>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7f4483a6a440>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})]) [repeated 7x across cluster]
(WorkerDict pid=49813) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=50075) Actor use_remove_padding=True [repeated 7x across cluster]
(WorkerDict pid=49813) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.59s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.62s/it]
(WorkerDict pid=50074) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=49813) Qwen2ForCausalLM contains 7.62B parameters
(WorkerDict pid=50070) Total steps: 4350, num_warmup_steps: 0
(WorkerDict pid=50075) wrap_policy: functools.partial(<function _or_policy at 0x7f4483a6a560>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7f4483a6a440>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})]) [repeated 8x across cluster]
(WorkerDict pid=49813) Actor use_remove_padding=True [repeated 8x across cluster]
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
(WorkerDict pid=50070) No module named 'vllm._version'
(WorkerDict pid=50070)   from vllm.version import __version__ as VLLM_VERSION
(WorkerDict pid=50071) 
Loading checkpoint shards:  75%|███████▌  | 3/4 [00:10<00:03,  3.34s/it] [repeated 6x across cluster]
(WorkerDict pid=50071) 
Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.17s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.25s/it] [repeated 7x across cluster]
(WorkerDict pid=50074) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50071) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 6x across cluster]
(WorkerDict pid=49813) Before building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 20.63/95.00
(WorkerDict pid=50070) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) Total steps: 4350, num_warmup_steps: 0 [repeated 7x across cluster]
(WorkerDict pid=50070) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=50070)   warnings.warn(
(WorkerDict pid=50071) NCCL version 2.21.5+cuda12.4 [repeated 2x across cluster]
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 7x across cluster]
(WorkerDict pid=49813) No module named 'vllm._version' [repeated 7x across cluster]
(WorkerDict pid=49813)   from vllm.version import __version__ as VLLM_VERSION [repeated 7x across cluster]
(WorkerDict pid=50075) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html . [repeated 3x across cluster]
(WorkerDict pid=50075)   warnings.warn( [repeated 3x across cluster]
(WorkerDict pid=50071) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 4x across cluster]
(WorkerDict pid=50071) No module named 'vllm._version' [repeated 4x across cluster]
(WorkerDict pid=50071)   from vllm.version import __version__ as VLLM_VERSION [repeated 4x across cluster]
(WorkerDict pid=50075) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False} [repeated 3x across cluster]
(WorkerDict pid=49813) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50071) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50076) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 3x across cluster]
(WorkerDict pid=50076) No module named 'vllm._version' [repeated 3x across cluster]
(WorkerDict pid=50076)   from vllm.version import __version__ as VLLM_VERSION [repeated 3x across cluster]
(WorkerDict pid=49813) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.53 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=35.53 GB):   3%|▎         | 1/35 [00:00<00:23,  1.45it/s]
Capturing batches (avail_mem=35.08 GB):   3%|▎         | 1/35 [00:00<00:23,  1.45it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=35.08 GB):   6%|▌         | 2/35 [00:01<00:19,  1.73it/s]
Capturing batches (avail_mem=34.91 GB):   6%|▌         | 2/35 [00:01<00:19,  1.73it/s]
(WorkerDict pid=50074) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=50076) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=49813) 
Capturing batches (avail_mem=34.91 GB):   9%|▊         | 3/35 [00:01<00:19,  1.62it/s]
Capturing batches (avail_mem=34.74 GB):   9%|▊         | 3/35 [00:01<00:19,  1.62it/s]
(WorkerDict pid=50076) /usr/local/python/lib/python3.10/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 5x across cluster]
(WorkerDict pid=50076) No module named 'vllm._version' [repeated 5x across cluster]
(WorkerDict pid=50076)   from vllm.version import __version__ as VLLM_VERSION [repeated 5x across cluster]
(WorkerDict pid=50071) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50071) 
Capturing batches (avail_mem=34.69 GB):  20%|██        | 7/35 [00:04<00:16,  1.71it/s]
Capturing batches (avail_mem=34.54 GB):  20%|██        | 7/35 [00:04<00:16,  1.71it/s] [repeated 13x across cluster]
(WorkerDict pid=50074) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=50076) 
Capturing batches (avail_mem=34.54 GB):  23%|██▎       | 8/35 [00:03<00:11,  2.26it/s]
Capturing batches (avail_mem=34.39 GB):  23%|██▎       | 8/35 [00:03<00:11,  2.26it/s] [repeated 40x across cluster]
(WorkerDict pid=50076) 
  0%|          | 0/35 [00:00<?, ?it/s]
Capturing batches (avail_mem=35.96 GB):   0%|          | 0/35 [00:00<?, ?it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=32.36 GB):  91%|█████████▏| 32/35 [00:15<00:01,  2.34it/s]
Capturing batches (avail_mem=32.35 GB):  91%|█████████▏| 32/35 [00:15<00:01,  2.34it/s]
(WorkerDict pid=49813) 
Capturing batches (avail_mem=32.35 GB):  94%|█████████▍| 33/35 [00:15<00:00,  2.33it/s]
Capturing batches (avail_mem=32.35 GB):  94%|█████████▍| 33/35 [00:15<00:00,  2.33it/s]
(WorkerDict pid=50076) 
Capturing batches (avail_mem=33.24 GB):  57%|█████▋    | 20/35 [00:08<00:06,  2.38it/s]
Capturing batches (avail_mem=33.16 GB):  57%|█████▋    | 20/35 [00:08<00:06,  2.38it/s] [repeated 47x across cluster]
(WorkerDict pid=49813) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=49813) After building sglang_async rollout, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) After building sharding manager, memory allocated (GB): 3.55, memory reserved (GB): 17.57, device memory used/total (GB): 26.44/95.00
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=49813)   warnings.warn(
(WorkerDict pid=50074) 
Capturing batches (avail_mem=32.78 GB):  94%|█████████▍| 33/35 [00:13<00:00,  2.49it/s]
Capturing batches (avail_mem=32.78 GB):  94%|█████████▍| 33/35 [00:13<00:00,  2.49it/s] [repeated 8x across cluster]
(WorkerDict pid=50076) 
Capturing batches (avail_mem=32.80 GB):  89%|████████▊ | 31/35 [00:13<00:01,  2.49it/s]
Capturing batches (avail_mem=32.79 GB):  89%|████████▊ | 31/35 [00:13<00:01,  2.49it/s] [repeated 21x across cluster]
(WorkerDict pid=50076) kwargs: {'n': 16, 'max_new_tokens': 1024, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False} [repeated 3x across cluster]
(WorkerDict pid=50076) /usr/local/python/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html . [repeated 3x across cluster]
(WorkerDict pid=50076)   warnings.warn( [repeated 3x across cluster]
(TaskRunner pid=362057) Using LocalLogger is deprecated. The constructor API will change 
(TaskRunner pid=362057) Checkpoint tracker file does not exist: %s /tmp/ray/session_2025-05-01_20-49-28_274415_338741/runtime_resources/working_dir_files/_ray_pkg_4ce66e647b0fe766/checkpoints/test/latest_checkpointed_iteration.txt
(TaskRunner pid=362057) Training from scratch
(TaskRunner pid=362057) test_gen_batch meta info: {'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': False, 'validate': True}
(WorkerDict pid=49813) /usr/local/python/lib/python3.10/site-packages/sglang/srt/utils.py:888: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
(WorkerDict pid=49813)   tensor_data = torch.ByteTensor(
(WorkerDict pid=50076) 
Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.44it/s]
Capturing batches (avail_mem=32.77 GB): 100%|██████████| 35/35 [00:15<00:00,  2.32it/s] [repeated 6x across cluster]
(WorkerDict pid=50071) /usr/local/python/lib/python3.10/site-packages/sglang/srt/utils.py:888: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.) [repeated 3x across cluster]
(WorkerDict pid=50071)   tensor_data = torch.ByteTensor( [repeated 3x across cluster]
(TaskRunner pid=362057) validation generation end
(TaskRunner pid=362057) [prompt] system
(TaskRunner pid=362057) 
(TaskRunner pid=362057)                             You are a math expert. You are given a question and you need to solve it step by step.  
(TaskRunner pid=362057)                             `calc_gsm8k_reward` is a tool for calculating the reward of gsm8k. You should use this 
(TaskRunner pid=362057)                             tool to calculate the reward of your answer(1.0 if your answer is correct, 0.0 if your 
(TaskRunner pid=362057)                             answer is incorrect) before submitting it and refine your answer if necessary. Put your 
(TaskRunner pid=362057)                             final answer in the format of `#### <answer>`.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) # Tools
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You may call one or more functions to assist with the user query.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) You are provided with function signatures within <tools></tools> XML tags:
(TaskRunner pid=362057) <tools>
(TaskRunner pid=362057) {"type": "function", "function": {"name": "calc_gsm8k_reward", "description": "A tool for calculating the reward of gsm8k. (1.0 if your answer is correct, 0.0 if your answer is incorrect)", "parameters": {"type": "object", "properties": {"answer": {"type": "string", "description": "The model's answer to the GSM8K math problem, must be a digits", "enum": null}}, "required": ["answer"]}, "strict": false}}
(TaskRunner pid=362057) </tools>
(TaskRunner pid=362057) 
(TaskRunner pid=362057) For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": <function-name>, "arguments": <args-json-object>}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) user
(TaskRunner pid=362057) Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 
(TaskRunner pid=362057)         You must use the `calc_gsm8k_reward` tool to calculate the reward 
(TaskRunner pid=362057)         of your answer(1.0 if your answer is correct, 0.0 if your answer is incorrect) 
(TaskRunner pid=362057)         before submitting it at least once and refine your answer if necessary. 
(TaskRunner pid=362057)         Put your final answer in the format of `#### <answer>`.
(TaskRunner pid=362057)     
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) 
(TaskRunner pid=362057) [response] To find out how much Janet makes every day at the farmers' market, we need to follow these steps:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Calculate the total number of eggs laid by the ducks per day.
(TaskRunner pid=362057) 2. Subtract the number of eggs Janet eats for breakfast and the number of eggs she uses for baking muffins.
(TaskRunner pid=362057) 3. Multiply the remaining number of eggs by the price per egg to find out how much she makes at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Let's do the calculations:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) 1. Total number of eggs laid per day: 16
(TaskRunner pid=362057) 2. Number of eggs Janet eats for breakfast: 3
(TaskRunner pid=362057) 3. Number of eggs Janet uses for baking muffins: 4
(TaskRunner pid=362057) 4. Remaining number of eggs: 16 - 3 - 4 = 9
(TaskRunner pid=362057) 5. Price per egg: $2
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's calculate the total amount Janet makes at the farmers' market:
(TaskRunner pid=362057) 
(TaskRunner pid=362057) \[ \text{Total amount} = \text{Remaining number of eggs} \times \text{Price per egg} \]
(TaskRunner pid=362057) \[ \text{Total amount} = 9 \times 2 = 18 \]
(TaskRunner pid=362057) 
(TaskRunner pid=362057) So, Janet makes $18 every day at the farmers' market.
(TaskRunner pid=362057) 
(TaskRunner pid=362057) Now, let's use the `calc_gsm8k_reward` tool to check the correctness of our answer.
(TaskRunner pid=362057) <tool_call>
(TaskRunner pid=362057) {"name": "calc_gsm8k_reward", "arguments": "{\"answer\": \"18\"}"}
(TaskRunner pid=362057) </tool_call>
(TaskRunner pid=362057) tool
(TaskRunner pid=362057) Current parsed answer='18' reward=1.0
(TaskRunner pid=362057) assistant
(TaskRunner pid=362057) #### 18
(TaskRunner pid=362057) [ground_truth] 18
(TaskRunner pid=362057) [score] 1.0
(TaskRunner pid=362057) 'Initial validation metrics: {}'
(TaskRunner pid=362057) step:0
(TaskRunner pid=362057) 
Training Progress:   0%|          | 0/4350 [00:00<?, ?it/s]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(TaskRunner pid=362057) step:1 - global_seqlen/min:379680.000 - global_seqlen/max:457144.000 - global_seqlen/minmax_diff:77464.000 - global_seqlen/balanced_min:416128.000 - global_seqlen/balanced_max:416129.000 - global_seqlen/mean:416128.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.004 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:0.010 - perf/mfu/actor:0.700 - perf/max_memory_allocated_gb:49.644 - perf/max_memory_reserved_gb:80.434 - perf/cpu_memory_used_gb:108.922 - actor/lr:0.000 - training/global_step:1.000 - training/epoch:0.000 - critic/score/mean:0.905 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.905 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.653 - critic/advantages/min:-1.436 - critic/returns/mean:-0.000 - critic/returns/max:0.653 - critic/returns/min:-1.436 - response_length/mean:364.025 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.026 - prompt_length/mean:448.727 - prompt_length/max:533.000 - prompt_length/min:410.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:118.270 - timing_s/reward:1.181 - timing_s/old_log_prob:52.755 - timing_s/ref:54.759 - timing_s/adv:0.059 - timing_s/update_actor:188.158 - timing_s/step:415.307 - timing_per_token_ms/gen:0.079 - timing_per_token_ms/update_actor:0.057 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.016 - perf/total_num_tokens:3329030.000 - perf/time_per_step:415.307 - perf/throughput:1001.979
(TaskRunner pid=362057) 
Training Progress:   0%|          | 1/4350 [06:57<504:26:22, 417.56s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(TaskRunner pid=362057) step:2 - global_seqlen/min:407422.000 - global_seqlen/max:440440.000 - global_seqlen/minmax_diff:33018.000 - global_seqlen/balanced_min:418815.000 - global_seqlen/balanced_max:418816.000 - global_seqlen/mean:418815.625 - actor/entropy_loss:0.218 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.006 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.703 - perf/max_memory_allocated_gb:54.050 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.974 - actor/lr:0.000 - training/global_step:2.000 - training/epoch:0.000 - critic/score/mean:0.903 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.903 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.001 - critic/advantages/max:1.436 - critic/advantages/min:-1.677 - critic/returns/mean:-0.001 - critic/returns/max:1.436 - critic/returns/min:-1.677 - response_length/mean:368.902 - response_length/max:1024.000 - response_length/min:156.000 - response_length/clip_ratio:0.028 - prompt_length/mean:449.098 - prompt_length/max:559.000 - prompt_length/min:409.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:112.222 - timing_s/reward:1.169 - timing_s/old_log_prob:50.543 - timing_s/ref:49.970 - timing_s/adv:0.057 - timing_s/update_actor:188.153 - timing_s/step:402.215 - timing_per_token_ms/gen:0.074 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3350525.000 - perf/time_per_step:402.215 - perf/throughput:1041.273
(TaskRunner pid=362057) 
Training Progress:   0%|          | 2/4350 [13:40<493:36:46, 408.70s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 7x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:3 - global_seqlen/min:387116.000 - global_seqlen/max:433827.000 - global_seqlen/minmax_diff:46711.000 - global_seqlen/balanced_min:416119.000 - global_seqlen/balanced_max:416120.000 - global_seqlen/mean:416119.125 - actor/entropy_loss:0.212 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:-0.001 - actor/pg_clipfrac:0.000 - actor/ppo_kl:-0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.709 - perf/max_memory_allocated_gb:54.534 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:109.965 - actor/lr:0.000 - training/global_step:3.000 - training/epoch:0.000 - critic/score/mean:0.890 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.890 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:0.250 - critic/advantages/min:-3.750 - critic/returns/mean:-0.000 - critic/returns/max:0.250 - critic/returns/min:-3.750 - response_length/mean:364.510 - response_length/max:1024.000 - response_length/min:150.000 - response_length/clip_ratio:0.020 - prompt_length/mean:448.223 - prompt_length/max:534.000 - prompt_length/min:411.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:125.431 - timing_s/reward:1.166 - timing_s/old_log_prob:49.390 - timing_s/ref:49.245 - timing_s/adv:0.058 - timing_s/update_actor:185.307 - timing_s/step:410.658 - timing_per_token_ms/gen:0.084 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3328953.000 - perf/time_per_step:410.658 - perf/throughput:1013.299
(TaskRunner pid=362057) 
Training Progress:   0%|          | 3/4350 [20:30<494:43:43, 409.71s/it]
(TaskRunner pid=362057) list(reward_extra_infos_dict.keys())=[]
(WorkerDict pid=50075) WARN: rank 5 grad_norm is not finite: nan [repeated 6x across cluster]
(WorkerDict pid=49813) WARN: rank 0 grad_norm is not finite: nan
(WorkerDict pid=50070) WARN: rank 1 grad_norm is not finite: nan
(TaskRunner pid=362057) step:4 - global_seqlen/min:408386.000 - global_seqlen/max:454859.000 - global_seqlen/minmax_diff:46473.000 - global_seqlen/balanced_min:422598.000 - global_seqlen/balanced_max:422599.000 - global_seqlen/mean:422598.750 - actor/entropy_loss:0.216 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.003 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:nan - perf/mfu/actor:0.705 - perf/max_memory_allocated_gb:55.359 - perf/max_memory_reserved_gb:80.617 - perf/cpu_memory_used_gb:110.119 - actor/lr:0.000 - training/global_step:4.000 - training/epoch:0.000 - critic/score/mean:0.872 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.872 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.016 - critic/advantages/min:-0.465 - critic/returns/mean:-0.000 - critic/returns/max:2.016 - critic/returns/min:-0.465 - response_length/mean:377.724 - response_length/max:1024.000 - response_length/min:95.000 - response_length/clip_ratio:0.029 - prompt_length/mean:447.664 - prompt_length/max:550.000 - prompt_length/min:415.000 - prompt_length/clip_ratio:0.000 - timing_s/gen:133.983 - timing_s/reward:1.183 - timing_s/old_log_prob:50.797 - timing_s/ref:50.276 - timing_s/adv:0.059 - timing_s/update_actor:190.039 - timing_s/step:426.396 - timing_per_token_ms/gen:0.087 - timing_per_token_ms/update_actor:0.056 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.015 - perf/total_num_tokens:3380790.000 - perf/time_per_step:426.396 - perf/throughput:991.094
(TaskRunner pid=362057) 
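The repeated `WARN: rank N grad_norm is not finite: nan` messages above are the most suspicious signal in this run and likely relate to the non-convergence mentioned in the PR description. As a minimal plain-Python sketch of the check such a warning implies (an assumed mitigation, not verl's actual trainer code): the global grad norm is the L2 reduction of per-parameter norms, and if any shard contributes a NaN/Inf, the global norm is non-finite and the optimizer step should be skipped rather than applied.

```python
import math

def global_grad_norm(per_param_norms):
    """Reduce per-parameter gradient norms into a global L2 norm.

    Hypothetical helper for illustration only (not verl's API): returns the
    global norm plus a flag indicating whether it is safe to apply the update.
    A single NaN/Inf per-parameter norm poisons the global norm, matching the
    "grad_norm is not finite: nan" warnings seen in the log.
    """
    total = math.sqrt(sum(n * n for n in per_param_norms))
    # Skip the optimizer step when the reduced norm is NaN or Inf.
    return total, math.isfinite(total)
```

In a real FSDP trainer the equivalent check would sit right after `torch.nn.utils.clip_grad_norm_`, zeroing the gradients and dropping that step when the flag is false, so NaN gradients never reach the weights.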

The versions of my Python packages are as follows:

Package                           Version       Editable project location
--------------------------------- ------------- -------------------------
absl-py                           2.1.0
accelerate                        1.6.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.11.18
aiohttp-cors                      0.8.1
aiosignal                         1.3.2
annotated-types                   0.7.0
anthropic                         0.50.0
antlr4-python3-runtime            4.9.3
anyio                             4.9.0
asttokens                         3.0.0
astunparse                        1.6.3
async-timeout                     5.0.1
attrs                             25.3.0
beautifulsoup4                    4.13.4
blobfile                          3.0.0
boto3                             1.36.16
botocore                          1.36.17
cachetools                        5.5.2
certifi                           2025.4.26
cffi                              1.17.1
cfgv                              3.4.0
chardet                           5.2.0
charset-normalizer                3.4.1
click                             8.1.8
cloudpickle                       3.1.1
cmake                             3.31.4
codetiming                        1.4.0
colorful                          0.5.6
compressed-tensors                0.9.4
cuda-bindings                     12.8.0
cuda-python                       12.8.0
datasets                          3.5.1
decorator                         5.2.1
decord                            0.6.0
dill                              0.3.8
diskcache                         5.6.3
distlib                           0.3.9
distro                            1.9.0
docker-pycreds                    0.4.0
duckduckgo_search                 8.0.1
einops                            0.8.1
einops-exts                       0.0.4
exceptiongroup                    1.2.2
executing                         2.2.0
expecttest                        0.3.0
fastapi                           0.115.12
filelock                          3.18.0
flamingo-pytorch                  0.1.2
flash_attn                        2.7.4.post1
flashinfer-python                 0.2.3
frozenlist                        1.6.0
fsspec                            2025.3.0
ftfy                              6.3.1
gguf                              0.10.0
gitdb                             4.0.12
GitPython                         3.1.44
google-api-core                   2.24.2
google-auth                       2.39.0
googleapis-common-protos          1.70.0
grpcio                            1.71.0
h11                               0.16.0
hf_transfer                       0.1.9
httpcore                          1.0.9
httptools                         0.6.4
httpx                             0.28.1
huggingface-hub                   0.30.2
hydra-core                        1.3.2
hypothesis                        6.125.2
identify                          2.6.10
idna                              3.10
importlib_metadata                8.7.0
interegular                       0.3.3
ipython                           8.36.0
jedi                              0.19.2
Jinja2                            3.1.6
jiter                             0.9.0
jmespath                          1.0.1
jsonschema                        4.23.0
jsonschema-specifications         2025.4.1
lark                              1.2.2
liger_kernel                      0.5.8
lintrunner                        0.12.7
litellm                           1.67.5
llguidance                        0.7.19
llvmlite                          0.44.0
lm-format-enforcer                0.10.6
lxml                              5.4.0
Markdown                          3.7
markdown-it-py                    3.0.0
markdownify                       1.1.0
MarkupSafe                        3.0.2
matplotlib-inline                 0.1.7
mdurl                             0.1.2
mistral_common                    1.5.4
modelscope                        1.25.0
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.19.0
multidict                         6.4.3
multiprocess                      0.70.16
nanobind                          2.7.0
nest-asyncio                      1.6.0
networkx                          3.4.2
ninja                             1.11.1.3
nodeenv                           1.9.1
numba                             0.61.2
numpy                             1.26.4
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.2
nvidia-ml-py                      12.570.86
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
omegaconf                         2.3.0
open_clip_torch                   2.30.0
openai                            1.76.2
opencensus                        0.11.4
opencensus-context                0.1.3
opencv-contrib-python             4.11.0.86
opencv-python                     4.11.0.86
opencv-python-headless            4.11.0.86
optree                            0.14.0
orjson                            3.10.18
outlines                          0.0.46
packaging                         25.0
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post5
peft                              0.15.2
pexpect                           4.9.0
pillow                            11.1.0
pip                               25.0.1
platformdirs                      4.3.7
pre_commit                        4.2.0
primp                             0.15.0
prometheus_client                 0.21.1
prometheus-fastapi-instrumentator 7.1.0
prompt_toolkit                    3.0.51
propcache                         0.3.1
proto-plus                        1.26.1
protobuf                          6.30.2
psutil                            7.0.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
py-spy                            0.4.0
pyairports                        2.1.1
pyarrow                           20.0.0
pyasn1                            0.6.1
pyasn1_modules                    0.4.2
pybind11                          2.13.6
pycountry                         24.6.1
pycparser                         2.22
pycryptodomex                     3.22.0
pydantic                          2.11.4
pydantic_core                     2.33.2
Pygments                          2.19.1
pylatexenc                        2.10
pynvml                            12.0.0
python-dateutil                   2.9.0.post0
python-dotenv                     1.1.0
python-multipart                  0.0.20
pytz                              2025.2
PyYAML                            6.0.2
pyzmq                             26.4.0
ray                               2.45.0
referencing                       0.36.2
regex                             2024.11.6
requests                          2.32.3
rich                              14.0.0
rpds-py                           0.24.0
rsa                               4.9.1
ruamel.yaml                       0.18.10
ruamel.yaml.clib                  0.2.12
s3transfer                        0.11.2
safetensors                       0.5.3
sentencepiece                     0.2.0
sentry-sdk                        2.27.0
setproctitle                      1.3.6
setuptools                        80.1.0
sgl-kernel                        0.0.9.post2
sglang                            0.4.5.post3
six                               1.17.0
smart-open                        7.1.0
smmap                             5.0.2
smolagents                        1.14.0
sniffio                           1.3.1
sortedcontainers                  2.4.0
soundfile                         0.13.1
soupsieve                         2.7
sox                               1.5.0
stack-data                        0.6.3
starlette                         0.46.2
sympy                             1.13.1
tensorboard                       2.18.0
tensorboard-data-server           0.7.2
tensorboardX                      2.6.2.2
tensordict                        0.6.2
termcolor                         2.5.0
tiktoken                          0.8.0
timm                              1.0.14
tokenizers                        0.21.1
torch                             2.6.0+cu124
torch_memory_saver                0.0.5
torchao                           0.10.0
torchdata                         0.11.0
torchvision                       0.21.0
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.51.1
triton                            3.2.0
types-dataclasses                 0.6.6
typing_extensions                 4.13.2
typing-inspection                 0.4.0
tzdata                            2025.2
urllib3                           2.4.0
uvicorn                           0.34.2
uvloop                            0.21.0
verl                              0.2.0.dev0    /root/verl
virtualenv                        20.30.0
vllm                              0.6.3
wandb                             0.19.10
watchfiles                        1.0.5
wcwidth                           0.2.13
websockets                        15.0.1
Werkzeug                          3.1.3
wheel                             0.45.1
wrapt                             1.17.2
xformers                          0.0.27.post2
xgrammar                          0.1.17
xxhash                            3.5.0
yarl                              1.20.0
zipp                              3.21.0

My CUDA version is 12.4 (nvcc reports: Cuda compilation tools, release 12.4, V12.4.131).
The specific execution command is:

export VLLM_ATTENTION_BACKEND=XFORMERS
export NCCL_DEBUG=INFO
export GLOO_DEBUG=1

export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_SOCKET_IFNAME=bond1
export GLOO_SOCKET_IFNAME=bond1

export UCX_NET_DEVICES=bond1
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG="INFO"
export HYDRA_FULL_ERROR=1

ray job submit \
    -- python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=model/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.name=sglang_async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=${rollout_n} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    data.train_files=${TRAIN_FILE} \
    data.val_files=${TEST_FILE} \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=150

I sincerely hope to get some help. Thank you very much!

Thanks for sharing your settings and the early NaN. There appears to be some instability during training; we'll add more monitoring in wandb and examine it closely.

Same issue here. Any update on how this is solved? Thanks!

After several days of investigation, we identified that the issue is likely caused by excessively small log_p values. You can check the wandb log for details.

[wandb screenshot]

The log_p value of around -32 corresponds to a probability of approximately 1e-14, far below the bfloat16 machine epsilon of roughly 8e-3, so such values are lost entirely in bf16 arithmetic. This discrepancy leads to numerical instability. To address this, we are constructing an SFT dataset and initializing a cold-start model.
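To make the magnitudes concrete, here is a quick back-of-the-envelope check in plain Python (no verl code involved; the numbers are just the ones quoted above):

```python
import math

# A log-prob of -32, as seen in the wandb plots above.
log_p = -32.0
p = math.exp(log_p)
print(f"p = {p:.3e}")  # roughly 1.27e-14

# bfloat16 keeps 8 significand bits, so its machine epsilon is 2**-8.
bf16_eps = 2.0 ** -8
print(f"bf16 eps = {bf16_eps}")  # 0.00390625

# p is more than ten orders of magnitude below bf16's resolution near 1.0,
# so quantities like exp(log_p - old_log_p) computed in bf16 are unstable.
print(p < bf16_eps * 1e-10)  # True
```

Any PPO-style importance ratio that mixes probabilities this small with bf16-precision terms can round to 0 or blow up, which is consistent with the NaN gradients observed.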

@jiani-huang
Copy link

After several days of investigation, we identified that the issue is likely caused by excessively small log_p values. You can check the wandb log for details.


The log_p value of around -32 corresponds to a probability of approximately 1e-14, far below the bfloat16 machine epsilon of roughly 8e-3, so such values are lost entirely in bf16 arithmetic. This discrepancy leads to numerical instability. To address this, we are constructing an SFT dataset and initializing a cold-start model.

Thank you for your prompt response. Following your suggestion, I also reviewed my training logs. I am training the qwen-7b-instruct model on a custom dataset that requires tool calling, and I have encountered a similar issue where the log_p values are very small.

Here are some of the metrics from my training log:

(WorkerDict pid=3883786) {'actor/pg_loss': -0.16048212349414825, 'actor/pg_clipfrac': 0.0, 'actor/ppo_kl': 0.0, 'actor/pg_clipfrac_lower': 0.0, 'actor/policy_loss/old_log_prob': -0.5826774835586548, 'actor/policy_loss/log_prob': -0.58267742395401, 'actor/policy_loss/advantages': 0.04711419716477394, 'actor/policy_loss/old_log_prob_max': 0.0, 'actor/policy_loss/old_log_prob_min': -35.42192077636719, 'actor/policy_loss/log_prob_max': 0.0, 'actor/policy_loss/log_prob_min': -35.42192077636719, 'actor/policy_loss/advantages_max': 0.7302938103675842, 'actor/policy_loss/advantages_min': -0.7302965521812439}
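As a side note, one can flag at-risk runs directly from such logged metrics. A minimal sketch (the helper name and threshold are my own, not part of verl; the keys follow the dict above):

```python
import math

# Subset of the metrics dict logged above.
metrics = {
    "actor/policy_loss/old_log_prob_min": -35.42192077636719,
    "actor/policy_loss/log_prob_min": -35.42192077636719,
    "actor/policy_loss/advantages_max": 0.7302938103675842,
}

# Flag any min log-prob implying a probability below bfloat16 resolution.
THRESHOLD = math.log(2.0 ** -8)  # ~ -5.55

def flag_tiny_logprobs(metrics: dict) -> list:
    """Return the metric keys whose minimum log-prob is a NaN-risk signal."""
    return sorted(k for k, v in metrics.items()
                  if k.endswith("log_prob_min") and v < THRESHOLD)

print(flag_tiny_logprobs(metrics))
# Both *_log_prob_min keys are far below the threshold here.
```

With values around -35, both log-prob keys above trip the flag, matching the instability you observed.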

I would like to know: is constructing an SFT dataset for a cold start the only solution to this problem? Additionally, I noticed that in the training logs linked from the verl multi-turn rollout documentation (https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20), which are for a 3B model trained on GSM8K, grad_norm never becomes NaN. What explains the difference?

@SwordFaith
Copy link
Collaborator Author

After several days of investigation, we identified that the issue is likely caused by excessively small log_p values. You can check the wandb log for details.
The log_p value of around -32 corresponds to a probability of approximately 1e-14, far below the bfloat16 machine epsilon of roughly 8e-3, so such values are lost entirely in bf16 arithmetic. This discrepancy leads to numerical instability. To address this, we are constructing an SFT dataset and initializing a cold-start model.

Thank you for your prompt response. Following your suggestion, I also reviewed my training logs. I am training the qwen-7b-instruct model on a custom dataset that requires tool calling, and I have encountered a similar issue where the log_p values are very small.

Here are some of the metrics from my training log:

(WorkerDict pid=3883786) {'actor/pg_loss': -0.16048212349414825, 'actor/pg_clipfrac': 0.0, 'actor/ppo_kl': 0.0, 'actor/pg_clipfrac_lower': 0.0, 'actor/policy_loss/old_log_prob': -0.5826774835586548, 'actor/policy_loss/log_prob': -0.58267742395401, 'actor/policy_loss/advantages': 0.04711419716477394, 'actor/policy_loss/old_log_prob_max': 0.0, 'actor/policy_loss/old_log_prob_min': -35.42192077636719, 'actor/policy_loss/log_prob_max': 0.0, 'actor/policy_loss/log_prob_min': -35.42192077636719, 'actor/policy_loss/advantages_max': 0.7302938103675842, 'actor/policy_loss/advantages_min': -0.7302965521812439}

I would like to know: is constructing an SFT dataset for a cold start the only solution to this problem? Additionally, I noticed that in the training logs linked from the verl multi-turn rollout documentation (https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20), which are for a 3B model trained on GSM8K, grad_norm never becomes NaN. What explains the difference?

I apologize for my earlier assumption that the log_prob issue stemmed solely from differences with stable math/code training; that is not the root cause. After testing cold-start SFT variants, we found that adding cold-start SFT significantly raises log_prob_min and reduces some occurrences of NaN, but it does not completely resolve the issue. Along the way we identified and fixed a few bugs in the previous implementation, which appear to have resolved the NaN issue in the new Weights & Biases (wandb) logs and the new PR #1475. If you could help reproduce the results, it would be greatly appreciated. We sincerely apologize for any inconvenience caused by bugs in the earlier development release.

@jiani-huang
Copy link

I apologize for my earlier assumption that the log_prob issue stemmed solely from differences with stable math/code training; that is not the root cause. After testing cold-start SFT variants, we found that adding cold-start SFT significantly raises log_prob_min and reduces some occurrences of NaN, but it does not completely resolve the issue. Along the way we identified and fixed a few bugs in the previous implementation, which appear to have resolved the NaN issue in the new Weights & Biases (wandb) logs and the new PR #1475. If you could help reproduce the results, it would be greatly appreciated. We sincerely apologize for any inconvenience caused by bugs in the earlier development release.

Thank you very much! After running with your new PR #1475, the NaN issue of my experiments is fixed.

@supermancmk
Copy link

Thank you very much, but this still doesn't seem to work, and the model-training crash is not resolved.
I pulled the latest verl code without any modifications and ran the official GSM8K-with-tool example; the model still crashes during training and produces NaNs.
Any suggestions would be appreciated. Here is my wandb log:
@SwordFaith @zhaochenyang20


[wandb screenshots]

@FloSophorae
Copy link

FloSophorae commented May 27, 2025

Hello, may I ask if this PR supports the use of multi-turn + tools in multimodal models? If it is supported, how should I use it? And how can I integrate my own image processing tools into it?
@SwordFaith

@SwordFaith
Copy link
Collaborator Author

Hello, may I ask if this PR supports the use of multi-turn + tools in multimodal models? If it is supported, how should I use it? And how can I integrate my own image processing tools into it? @SwordFaith

We are currently addressing some issues that need to be resolved for Qwen2.5-VL training, which is part of our Update 2. For more details, see zhaochenyang20/Awesome-ML-SYS-Tutorial#132. If you're interested in collaborating with us, feel free to add me on WeChat; my ID is swordfaith.

GitMonkey0 pushed a commit to GitMonkey0/verl that referenced this pull request Jun 14, 2025
…olcengine#1037)

A redesigned version of volcengine#917 

## Current Status
[Develop log &
Tracker](zhaochenyang20/Awesome-ML-SYS-Tutorial#113)

**What Has Been Done**
- Async Rollout Refactoring: integrate with the tool server to
coordinate tool calls during generation, leveraging request IDs for
state and progress tracking; supports async multi-turn conversations in
Agentic RL training (with tool support).
- Async Request Management: encapsulate rollout requests into a unified
structure, enabling efficient tracking and handling of concurrent
multi-turn dialogues with ChatML-style messages.
- Extensible Tools: a modular design for adapting tools in the
OpenAIFunctionTool format (supported by both SGLang and vLLM); each tool
creates a separate instance, executes on tool call, computes a score
from the tool/environment state, and releases its resources.
- Multi-turn support has been implemented for the GSM8K task (a new
version is in progress). However, training has not yet converged, and we
hope the community will help investigate.

**What Is WIP**
- [x] Merge the loss mask from the last version into the training process
- [x] Add a more user-friendly tool config and e2e tests for GSM8K with
tool training
- [ ] Validate the multi-turn feature in open-source sandbox
environments

## Key features to be introduced in future versions

- Integrate a Ray-based agent trainer to enable explicit separation of
the rollout and training pipeline. Provide support for partial rollout
handling and fine-grained request state management.
- Extend the framework to support simulated user interactions (e.g.,
roleplay, interactive feedback) and more complex environment-in-the-loop
RL tasks.

**Future Plan**
[Discussion
Thread](zhaochenyang20/Awesome-ML-SYS-Tutorial#74 (comment))
[RFC
doc](https://github.com/SwordFaith/verl-sglang-dev-log/blob/main/rlhf/verl/multi-turn/veRL-multiturn-rollout-RFC.md)
will be updated soon.

## Contributors & Acknowledgement

- Xiang Long [mid.of.change@gmail.com](mailto:mid.of.change@gmail.com)
@SwordFaith (Design RFC & core-dev of refactor part)
- Yuzhen Zhou [zyzshishui@gmail.com](mailto:zyzshishui@gmail.com)
@zyzshishui (Core-dev)
- Chenyang Zhao [zhaochen20@outlook.com](mailto:zhaochen20@outlook.com)
@zhaochenyang20 (PM)
- Guanhua Wang @WANG-GH 
- Junrong Lin @ocss884 (verl-sglang support)
- Hanchen Zhang
[zhanghanchen77@gmail.com](mailto:zhanghanchen77@gmail.com)
- Haoran Wang [ubecwang@gmail.com](mailto:ubecwang@gmail.com)
- Rui Lu [learningrate1@gmail.com](mailto:learningrate1@gmail.com)
- Yujiang Li [liyujiang2020@gmail.com](mailto:liyujiang2020@gmail.com)
- Jiajun Li [guapisolo@gmail.com](mailto:guapisolo@gmail.com)
- Jin Pan [jpan236@wisc.edu](mailto:jpan236@wisc.edu)
- Zhi Zheng [zhengzhi@modelbest.cn](mailto:zhengzhi@modelbest.cn)
@zh-zheng

---------

Co-authored-by: zyzshishui <492129152@qq.com>
Co-authored-by: guanhua <281484683@qq.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
Co-authored-by: ocss884 <ocss.lin@gmail.com>
Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
Co-authored-by: HL <linhaibin.eric@gmail.com>