
Commit 4fa7ed6

[mcore] qwen2moe support (#1139)
Support the Qwen2MoE architecture with Megatron-Core, including:
* qwen2moe config converter
* qwen2moe model initializer
* a refactored online weight converter from mcore to vLLM
* qwen2moe online weight converter
* qwen2moe offline weight conversion script from HF to mcore
* a script to train Qwen1.5-MoE-A2.7B with 4 nodes

TODO: add an option to freeze the MoE router weights during training.
1 parent c54ec18 commit 4fa7ed6

File tree

13 files changed: +564 −108 lines
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
set -x

# 0. download the model
huggingface-cli download Qwen/Qwen1.5-MoE-A2.7B-Chat

# 1. convert the model to mcore format
# change HF_MODEL_PATH and DIST_CKPT_PATH to your own paths
HF_MODEL_PATH=/data/models/Qwen/Qwen1.5-MoE-A2.7B-Chat
DIST_CKPT_PATH=/data/mcore_ckpt/Qwen1.5-MoE-A2.7B-Chat
python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH

# 2. run the training script
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
train_files=$gsm8k_train_path
test_files=$gsm8k_test_path

NODES=4
PP=2
TP=4
CP=1
VLLM_TP=4

# RAY_ADDRESS='auto' ray job submit --working-dir . --
python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer' \
    algorithm.adv_estimator=gae \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=512 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=$HF_MODEL_PATH \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    critic.optim.lr=1e-5 \
    critic.model.path=$HF_MODEL_PATH \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=4 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_megatron_gsm8k_examples' \
    trainer.experiment_name='qwen1.5_moe_nochat' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=$NODES \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
    critic.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
    critic.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.actor.megatron.context_parallel_size=$CP \
    actor_rollout_ref.ref.megatron.context_parallel_size=$CP \
    critic.megatron.context_parallel_size=$CP \
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
    critic.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    trainer.total_epochs=100 $@
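As a quick cross-check of the sizes used above: 4 nodes × 8 GPUs give 32 GPUs in total, TP=4 × PP=2 × CP=1 occupies 8 of them per Megatron model replica (so the data-parallel size is 4), and the vLLM rollout workers form their own TP=4 groups. A minimal sketch of that arithmetic, purely illustrative and not part of the commit:

```python
# Sanity-check the parallel layout chosen in the script above (illustrative only).
nodes, gpus_per_node = 4, 8
tp, pp, cp = 4, 2, 1      # Megatron tensor / pipeline / context parallel sizes
vllm_tp = 4               # rollout (vLLM) tensor parallel size

world_size = nodes * gpus_per_node            # 32 GPUs
per_replica = tp * pp * cp                    # 8 GPUs per Megatron model replica
assert world_size % per_replica == 0
dp = world_size // per_replica                # 4 data-parallel replicas
assert world_size % vllm_tp == 0              # vLLM TP groups must also tile the cluster
print(f"world={world_size}, gpus/replica={per_replica}, dp={dp}")
```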

scripts/converter_hf_to_mcore.py

Lines changed: 73 additions & 23 deletions
@@ -22,9 +22,11 @@
 from megatron.core import parallel_state as mpu
 from megatron.core.dist_checkpointing.serialization import StrictHandling
 from megatron.core.models.gpt.gpt_model import ModelType
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from verl.utils.megatron_utils import convert_config, get_model
+from verl.models.mcore import hf_to_mcore_config
+from verl.utils.megatron_utils import get_model
 
 
 def _init_args():
@@ -51,6 +53,49 @@ def __init__(self):
         self.model = ModelConfig()
 
 
+def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config):
+    num_attention_heads = hf_config.num_attention_heads
+    hidden_dim = hf_config.hidden_size
+    head_dim = hidden_dim // num_attention_heads
+    with torch.no_grad():
+        model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
+        for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers):
+            layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight)
+
+            q = hf_layer.self_attn.q_proj.weight.view([num_attention_heads, -1, head_dim, hidden_dim])
+            k = hf_layer.self_attn.k_proj.weight.view([num_attention_heads, -1, head_dim, hidden_dim])
+            v = hf_layer.self_attn.v_proj.weight.view([num_attention_heads, -1, head_dim, hidden_dim])
+            qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
+
+            q_bias = hf_layer.self_attn.q_proj.bias.view([num_attention_heads, -1])
+            k_bias = hf_layer.self_attn.k_proj.bias.view([num_attention_heads, -1])
+            v_bias = hf_layer.self_attn.v_proj.bias.view([num_attention_heads, -1])
+            qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
+
+            layer.self_attention.linear_qkv.weight.copy_(qkv)
+            layer.self_attention.linear_qkv.bias.copy_(qkv_bias)
+
+            layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)
+            layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
+
+            layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
+
+            for idx, hf_expert in enumerate(hf_layer.mlp.experts):
+                fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
+                layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_(fc1_weight)
+                layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_(hf_expert.down_proj.weight)
+
+            layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight)
+            shared_fc1_weight = torch.cat(
+                [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]
+            )
+            layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
+            layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight)
+
+        model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
+        model.output_layer.weight.copy_(hf_model.lm_head.weight)
+
+
 def convert_hf_to_mcore(hf_model_path, output_path, test=False):
     os.makedirs(output_path, exist_ok=True)
     if len(os.listdir(output_path)) > 0 and not test:
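The QKV packing above is the subtle part of the conversion: Megatron-Core's fused `linear_qkv` stores rows grouped per attention head as [q-rows, k-rows, v-rows] for head 0, then head 1, and so on, while HF keeps three separate projection matrices. A standalone toy illustration of that reshaping, with made-up sizes and not part of the commit:

```python
import torch

# Toy sizes only (not the real Qwen1.5-MoE dimensions): 2 heads, head_dim 3, hidden 6.
num_heads, head_dim, hidden = 2, 3, 6
q_w = torch.arange(num_heads * head_dim * hidden, dtype=torch.float32).view(num_heads * head_dim, hidden)
k_w = torch.full((num_heads * head_dim, hidden), -1.0)
v_w = torch.full((num_heads * head_dim, hidden), -2.0)

# Same reshape/concat pattern as convert_checkpoint_from_transformers_to_megatron:
q = q_w.view(num_heads, -1, head_dim, hidden)
k = k_w.view(num_heads, -1, head_dim, hidden)
v = v_w.view(num_heads, -1, head_dim, hidden)
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden)

# Rows are now interleaved per head: head0 q, head0 k, head0 v, head1 q, ...
assert torch.equal(qkv[0 * head_dim:1 * head_dim], q_w[:head_dim])  # head0 q rows
assert torch.equal(qkv[1 * head_dim:2 * head_dim], k_w[:head_dim])  # head0 k rows
assert torch.equal(qkv[2 * head_dim:3 * head_dim], v_w[:head_dim])  # head0 v rows
print(qkv.shape)  # torch.Size([18, 6]) == (3 * num_heads * head_dim, hidden)
```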
@@ -69,21 +114,22 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
         context_parallel_size=1,
         expert_model_parallel_size=1,
     )
+    model_parallel_cuda_manual_seed(0)
 
     # init hf config
     hf_config = AutoConfig.from_pretrained(hf_model_path)
     print(hf_config)
-    megatron_config = MegatronConfig()
+
     cfg = Config()
     cfg.model.path = hf_model_path
-    tfconfig = convert_config(hf_config, megatron_config)
+    tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)
     tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
 
     # init megatron model
     def megatron_model_provider(pre_process, post_process):
-        from verl.utils.model import get_parallel_gptmodel_from_config
+        from verl.models.mcore import init_mcore_model
 
-        parallel_model = get_parallel_gptmodel_from_config(
+        parallel_model = init_mcore_model(
             tfconfig,
             hf_config,
             pre_process,
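The `model_parallel_cuda_manual_seed(0)` call added above matters because Megatron-Core keeps a tensor-parallel RNG tracker that must be seeded before any mcore weights are materialized. A minimal single-rank sketch of the setup sequence the converter relies on (assumes one CUDA device and the usual MASTER_ADDR/MASTER_PORT environment variables for torch.distributed):

```python
import torch
import torch.distributed as dist
from megatron.core import parallel_state as mpu
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed

# Single-process "cluster": rank 0 of world size 1.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl", world_size=1, rank=0)
torch.cuda.set_device(0)

mpu.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
    context_parallel_size=1,
    expert_model_parallel_size=1,
)
model_parallel_cuda_manual_seed(0)  # seed the TP RNG tracker before building the mcore model
```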
@@ -94,27 +140,31 @@ def megatron_model_provider(pre_process, post_process):
         return parallel_model
 
     model = get_model(
-        model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True
+        model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=False
     )
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
 
         # init hf model
-        hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)
+        hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16)
         ref_state_dict = hf_model.state_dict()
 
         # load hf state dict to megatron model
-        from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
-
-        load_state_dict_to_megatron_gptmodel(
-            state_dict=ref_state_dict,
-            wrapped_models=model,
-            config=hf_config,
-            params_dtype=torch.bfloat16,
-            is_value_model=False,
-        )
-        ssd = model[0].module.module.sharded_state_dict()
+        if "Qwen2MoeForCausalLM" in hf_config.architectures:
+            convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
+        else:
+            from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
+
+            load_state_dict_to_megatron_gptmodel(
+                state_dict=ref_state_dict,
+                wrapped_models=model,
+                config=hf_config,
+                params_dtype=torch.bfloat16,
+                is_value_model=False,
+            )
+
+        ssd = model[0].module.sharded_state_dict()
         del ref_state_dict, hf_model
 
         # save megatron model
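Related to the expert and shared-expert weights copied by the converter: `linear_fc1` stacks the HF `gate_proj` on top of `up_proj` because Megatron-Core's gated linear unit chunks its fc1 output in half and feeds the first half through the activation. A standalone toy check of that equivalence (illustrative sizes, not part of the commit; it assumes mcore's gate-first chunking order):

```python
import torch
import torch.nn.functional as F

hidden, ffn = 4, 6                  # toy sizes
x = torch.randn(2, hidden)
w_gate = torch.randn(ffn, hidden)   # plays the role of HF gate_proj.weight
w_up = torch.randn(ffn, hidden)     # plays the role of HF up_proj.weight

# HF Qwen2MoE expert MLP (before down_proj): silu(gate_proj(x)) * up_proj(x)
ref = F.silu(x @ w_gate.T) * (x @ w_up.T)

# mcore-style fused fc1: weights stacked [gate; up], output chunked into two halves
w_fc1 = torch.cat([w_gate, w_up])   # same stacking as fc1_weight in the converter
gate_out, up_out = (x @ w_fc1.T).chunk(2, dim=-1)
fused = F.silu(gate_out) * up_out

assert torch.allclose(ref, fused)
print("fused fc1 reproduces the separate gate/up projections")
```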
@@ -126,11 +176,11 @@ def megatron_model_provider(pre_process, post_process):
         model_test = get_model(
             model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True
         )
-        ssd2 = model_test[0].module.module.sharded_state_dict()
+        ssd2 = model_test[0].module.sharded_state_dict()
         dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)
 
-        sd = model[0].module.module.state_dict()
-        sd2 = model_test[0].module.module.state_dict()
+        sd = model[0].module.state_dict()
+        sd2 = model_test[0].module.state_dict()
         for k in sd.keys():
             if sd[k] is None:
                 continue
@@ -163,11 +213,11 @@ def megatron_value_model_provider(pre_process, post_process):
             model_type=ModelType.encoder_or_decoder,
             wrap_with_ddp=True,
         )
-        ssd2 = model_value[0].module.module.sharded_state_dict()
+        ssd2 = model_value[0].module.sharded_state_dict()
         dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL)
 
-        sd = model[0].module.module.state_dict()
-        sd2 = model_value[0].module.module.state_dict()
+        sd = model[0].module.state_dict()
+        sd2 = model_value[0].module.state_dict()
         for k in sd.keys():
             if sd[k] is None:
                 continue

verl/models/mcore/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -13,6 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .registry import get_mcore_forward_fn, hf_to_mcore_config, init_mcore_model
+from .registry import get_mcore_forward_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model
 
-__all__ = ["init_mcore_model", "hf_to_mcore_config", "get_mcore_forward_fn"]
+__all__ = ["init_mcore_model", "hf_to_mcore_config", "get_mcore_forward_fn", "get_mcore_weight_converter"]

verl/models/mcore/config_converter.py

Lines changed: 64 additions & 2 deletions
@@ -66,8 +66,70 @@ def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype) ->
 
 
 def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
-    # Qwen2MoeForCausalLM
-    raise NotImplementedError("Qwen2MoeForCausalLM is not supported yet")
+    from megatron.core import parallel_state as mpu
+
+    overlap_p2p_comm = (
+        mpu.get_virtual_pipeline_model_parallel_world_size() is not None
+        and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
+    )
+    batch_p2p_comm = False
+    transformer_config = TransformerConfig(
+        num_layers=hf_config.num_hidden_layers,
+        hidden_size=hf_config.hidden_size,
+        num_attention_heads=hf_config.num_attention_heads,
+        num_query_groups=hf_config.num_key_value_heads,
+        attention_dropout=hf_config.attention_dropout,
+        hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0),
+        activation_func=F.silu,
+        normalization="RMSNorm",
+        gated_linear_unit=True,
+        use_cpu_initialization=False,
+        add_bias_linear=False,
+        pipeline_dtype=dtype,
+        params_dtype=dtype,
+        variable_seq_lengths=True,
+        masked_softmax_fusion=True,
+        attention_backend=AttnBackend.flash,
+        # attention_backend=AttnBackend.fused,
+        bf16=dtype is torch.bfloat16,
+        layernorm_epsilon=hf_config.rms_norm_eps,
+        ffn_hidden_size=hf_config.intermediate_size,
+        # parallel config
+        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
+        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
+        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
+        context_parallel_size=mpu.get_context_parallel_world_size(),
+        overlap_p2p_comm=overlap_p2p_comm,
+        batch_p2p_comm=batch_p2p_comm,
+        sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,
+        # moe specific
+        moe_ffn_hidden_size=hf_config.moe_intermediate_size,
+        moe_token_dispatcher_type="alltoall",
+        moe_router_bias_update_rate=0.001,
+        moe_router_topk=hf_config.num_experts_per_tok,
+        num_moe_experts=hf_config.num_experts,
+        moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size,
+        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
+        # moe_aux_loss_coeff=0.0,
+        moe_router_load_balancing_type="aux_loss",
+        moe_shared_expert_overlap=True,
+        # moe_permute_fusion=True, # need TE 2.1+
+        moe_grouped_gemm=True,
+        moe_router_score_function="softmax",
+        # # mcore 0.12 moe
+        # moe_router_dtype="fp64",
+        # disable_bf16_reduced_precision_matmul=True,
+        # other
+        # deallocate_pipeline_outputs=True,
+        # gradient_accumulation_fusion=True,
+        persist_layer_norm=True,
+        bias_activation_fusion=True,
+        bias_dropout_fusion=True,
+        # qwen specific
+        moe_router_pre_softmax=True,
+        add_qkv_bias=True,
+    )
+    return transformer_config
 
 
 def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
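With the stub replaced, a Qwen2MoE HF config can be translated through the same registry entry point the converter script uses. A minimal usage sketch, assuming Megatron's parallel state has already been initialized on the current rank (for example via the single-rank setup in scripts/converter_hf_to_mcore.py):

```python
import torch
from transformers import AutoConfig

from verl.models.mcore import hf_to_mcore_config

# Requires mpu.initialize_model_parallel(...) to have been called beforehand.
hf_config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B-Chat")
assert "Qwen2MoeForCausalLM" in hf_config.architectures

tf_config = hf_to_mcore_config(hf_config, torch.bfloat16)
print(tf_config.num_moe_experts)                      # routed experts, from hf_config.num_experts
print(tf_config.moe_router_topk)                      # experts per token, from hf_config.num_experts_per_tok
print(tf_config.moe_shared_expert_intermediate_size)  # shared-expert FFN width
```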
