[megatron] feat: support gpt-oss #4323
Conversation
Code Review
This pull request adds support for gpt-oss by introducing a bshd data format to handle models that do not support the existing thd format. The changes are spread across model forwarding logic, configuration files, and utility functions. A new test script for gpt-oss is also included.
My review has identified a critical bug in the model_forward.py file where a hardcoded sequence_parallel flag can lead to runtime errors due to shape mismatches. Additionally, there's an inconsistency in the new test script where the default model path does not match the path where the model is saved, which would cause the script to fail. I've provided suggestions to fix both issues.
verl/models/mcore/model_forward.py (outdated)
)
if post_process and logits_processor is not None:
    args = {
        k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=True, pre_process=True)[0]
The sequence_parallel argument for preprocess_bshd is hardcoded to True when processing logits_processor_args. This is inconsistent with the processing of input_ids, which uses the sp variable. If sequence parallelism is disabled (sp=False), this will lead to a shape mismatch between the model's output and the logits_processor arguments, causing a runtime error. The sp variable should be used for consistency.
Suggested change:
-        k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=True, pre_process=True)[0]
+        k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=sp, pre_process=True)[0]
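To see why the two flags must agree, here is a minimal sketch in plain PyTorch (independent of Megatron and of the actual `preprocess_bshd` implementation; `split_along_seq`, `tp_size`, and the tensor shapes are illustrative assumptions): with sequence parallelism, preprocessing splits tensors along the sequence dimension, so hardcoding `sequence_parallel=True` shrinks the logits-processor args while a model run with `sp=False` still produces outputs of the full sequence length.

```python
import torch

def split_along_seq(x: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
    # Hypothetical stand-in for sequence-parallel preprocessing:
    # the tensor is chunked along the sequence dimension (dim=1 in bshd layout).
    return x.chunk(tp_size, dim=1)[rank]

batch, seqlen, tp_size, rank = 2, 8, 2, 0
model_output = torch.randn(batch, seqlen, 4)        # sp=False: output keeps full seqlen
labels = torch.randint(0, 10, (batch, seqlen))      # a typical logits_processor arg

# Hardcoded sequence_parallel=True splits the logits-processor args regardless:
split_labels = split_along_seq(labels, tp_size, rank)

print(model_output.shape, split_labels.shape)
# torch.Size([2, 8, 4]) vs torch.Size([2, 4]) -> any per-token op now fails, e.g.
# torch.gather(model_output, 1, split_labels.unsqueeze(-1))  # RuntimeError: size mismatch
```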
test_dapo_gptoss_20b_megatron.sh

NNODES=${NNODES:-1}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/gpt-oss-20b"}
The default MODEL_PATH is inconsistent with the output_dir in the get_model.py script. The get_model.py script saves the model to "$HOME/models/gpt-oss-20b-bf16", but MODEL_PATH defaults to "${RAY_DATA_HOME}/models/gpt-oss-20b", which resolves to "${HOME}/verl/models/gpt-oss-20b". This will cause the script to fail with a "model not found" error unless MODEL_PATH is explicitly set. To ensure consistency, the default MODEL_PATH should point to the correct directory.
Suggested change:
-MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/gpt-oss-20b"}
+MODEL_PATH=${MODEL_PATH:-"$HOME/models/gpt-oss-20b-bf16"}
For now (latest TE = 2.10), gpt-oss's optimized attention kernel does not support the thd format, so we use the bshd format here. When the bshd format is used, input_ids must be padded to the longest sequence length in the batch, so we recommend disabling dynamic batch size and setting the micro batch size to 1 to avoid padding, although it is fine to try micro_batch_size > 1. See `test_dapo_gptoss_20b_megatron.sh` for an example.

<img width="1299" height="867" alt="image" src="https://github.com/user-attachments/assets/b166a4b7-9c3a-4840-84c1-e8de02b506db" />

The training crashes with a mismatch; further experiments with MIS/TIS or fp16 are needed.
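As a rough illustration of the padding behaviour described above (a minimal sketch, not the actual verl data path; `pad_to_longest` and the pad id are hypothetical), bshd requires every sequence in a micro batch to be right-padded to the longest one, whereas the thd format would pack the same tokens into one row and only track sequence boundaries:

```python
import torch

def pad_to_longest(sequences: list[list[int]], pad_id: int = 0):
    """Right-pad variable-length sequences to the longest one in the batch (bshd-style).

    In the thd (packed) format the same tokens would be concatenated into a single
    row with cu_seqlens marking the boundaries, so no padding would be needed.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        input_ids[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        attention_mask[i, : len(seq)] = 1
    return input_ids, attention_mask

# With micro_batch_size=1 each micro batch holds a single sequence, so max_len equals
# that sequence's own length and no pad tokens are wasted; with micro_batch_size>1 the
# shorter sequences are padded up to the longest one in the batch.
ids, mask = pad_to_longest([[5, 6, 7, 8, 9], [5, 6]], pad_id=0)
print(ids)   # tensor([[5, 6, 7, 8, 9], [5, 6, 0, 0, 0]])
print(mask)  # tensor([[1, 1, 1, 1, 1], [1, 1, 0, 0, 0]])
```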