from megatron.core import parallel_state as mpu
from megatron.core.dist_checkpointing.serialization import StrictHandling
from megatron.core.models.gpt.gpt_model import ModelType
+ from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import AutoConfig, AutoModelForCausalLM

- from verl.utils.megatron_utils import convert_config, get_model
+ from verl.models.mcore import hf_to_mcore_config
+ from verl.utils.megatron_utils import get_model


def _init_args():
@@ -51,6 +53,49 @@ def __init__(self):
        self.model = ModelConfig()


+ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config):
+     num_attention_heads = hf_config.num_attention_heads
+     hidden_dim = hf_config.hidden_size
+     head_dim = hidden_dim // num_attention_heads
+     with torch.no_grad():
+         model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
+         for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers):
+             layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight)
+
+             q = hf_layer.self_attn.q_proj.weight.view([num_attention_heads, -1, head_dim, hidden_dim])
+             k = hf_layer.self_attn.k_proj.weight.view([num_attention_heads, -1, head_dim, hidden_dim])
+             v = hf_layer.self_attn.v_proj.weight.view([num_attention_heads, -1, head_dim, hidden_dim])
+             qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
+
+             q_bias = hf_layer.self_attn.q_proj.bias.view([num_attention_heads, -1])
+             k_bias = hf_layer.self_attn.k_proj.bias.view([num_attention_heads, -1])
+             v_bias = hf_layer.self_attn.v_proj.bias.view([num_attention_heads, -1])
+             qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
+
+             layer.self_attention.linear_qkv.weight.copy_(qkv)
+             layer.self_attention.linear_qkv.bias.copy_(qkv_bias)
+
+             layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)
+             layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
+
+             layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
+
+             for idx, hf_expert in enumerate(hf_layer.mlp.experts):
+                 fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
+                 layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_(fc1_weight)
+                 layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_(hf_expert.down_proj.weight)
+
+             layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight)
+             shared_fc1_weight = torch.cat(
+                 [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]
+             )
+             layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
+             layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight)
+
+         model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
+         model.output_layer.weight.copy_(hf_model.lm_head.weight)
+
+
def convert_hf_to_mcore(hf_model_path, output_path, test=False):
    os.makedirs(output_path, exist_ok=True)
    if len(os.listdir(output_path)) > 0 and not test:
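Note (not part of the patch): the `view`/`cat` sequence in the new converter interleaves Q, K and V per attention head before copying into Megatron's fused `linear_qkv` weight. A minimal, self-contained sketch of that layout, with made-up shapes and assuming `num_key_value_heads == num_attention_heads` (which the reshape by `num_attention_heads` above relies on):

import torch

# Illustrative shapes only; the converter takes the real values from hf_config.
num_attention_heads, head_dim = 4, 8
hidden_dim = num_attention_heads * head_dim

q_proj = torch.randn(num_attention_heads * head_dim, hidden_dim)
k_proj = torch.randn(num_attention_heads * head_dim, hidden_dim)
v_proj = torch.randn(num_attention_heads * head_dim, hidden_dim)

# Same reshape as above: one group per head, [q, k, v] stacked inside each group,
# so the fused weight rows are ordered q0, k0, v0, q1, k1, v1, ...
q = q_proj.view(num_attention_heads, -1, head_dim, hidden_dim)
k = k_proj.view(num_attention_heads, -1, head_dim, hidden_dim)
v = v_proj.view(num_attention_heads, -1, head_dim, hidden_dim)
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()

assert qkv.shape == (3 * num_attention_heads * head_dim, hidden_dim)
assert torch.equal(qkv[:head_dim], q_proj[:head_dim])                  # head 0: Q rows
assert torch.equal(qkv[head_dim:2 * head_dim], k_proj[:head_dim])      # head 0: K rows
assert torch.equal(qkv[2 * head_dim:3 * head_dim], v_proj[:head_dim])  # head 0: V rows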
@@ -69,21 +114,22 @@ def convert_hf_to_mcore(hf_model_path, output_path, test=False):
        context_parallel_size=1,
        expert_model_parallel_size=1,
    )
+     model_parallel_cuda_manual_seed(0)

    # init hf config
    hf_config = AutoConfig.from_pretrained(hf_model_path)
    print(hf_config)
-     megatron_config = MegatronConfig()
+
    cfg = Config()
    cfg.model.path = hf_model_path
-     tfconfig = convert_config(hf_config, megatron_config)
+     tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)
    tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)

    # init megatron model
    def megatron_model_provider(pre_process, post_process):
-         from verl.utils.model import get_parallel_gptmodel_from_config
+         from verl.models.mcore import init_mcore_model

-         parallel_model = get_parallel_gptmodel_from_config(
+         parallel_model = init_mcore_model(
            tfconfig,
            hf_config,
            pre_process,
@@ -94,27 +140,31 @@ def megatron_model_provider(pre_process, post_process):
        return parallel_model

    model = get_model(
-         model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True
+         model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=False
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # init hf model
-         hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)
+         hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16)
        ref_state_dict = hf_model.state_dict()

        # load hf state dict to megatron model
-         from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
-
-         load_state_dict_to_megatron_gptmodel(
-             state_dict=ref_state_dict,
-             wrapped_models=model,
-             config=hf_config,
-             params_dtype=torch.bfloat16,
-             is_value_model=False,
-         )
-         ssd = model[0].module.module.sharded_state_dict()
+         if "Qwen2MoeForCausalLM" in hf_config.architectures:
+             convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
+         else:
+             from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
+
+             load_state_dict_to_megatron_gptmodel(
+                 state_dict=ref_state_dict,
+                 wrapped_models=model,
+                 config=hf_config,
+                 params_dtype=torch.bfloat16,
+                 is_value_model=False,
+             )
+
+         ssd = model[0].module.sharded_state_dict()
        del ref_state_dict, hf_model

        # save megatron model
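Note (not part of the patch): `hf_config.architectures` is the architecture list transformers reads from the checkpoint's config.json, so the new branch only takes the hand-written conversion path for Qwen2-MoE models and keeps the existing generic loader for everything else. A small check along those lines (the model id is just an example of a Qwen2-MoE checkpoint, not something this patch pins down):

from transformers import AutoConfig

# Example only; any config whose architectures list contains "Qwen2MoeForCausalLM"
# goes through convert_checkpoint_from_transformers_to_megatron above.
hf_config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
print(hf_config.architectures)  # expected: ['Qwen2MoeForCausalLM']
use_custom_path = "Qwen2MoeForCausalLM" in hf_config.architectures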
@@ -126,11 +176,11 @@ def megatron_model_provider(pre_process, post_process):
        model_test = get_model(
            model_provider_func=megatron_model_provider, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True
        )
-         ssd2 = model_test[0].module.module.sharded_state_dict()
+         ssd2 = model_test[0].module.sharded_state_dict()
        dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)

-         sd = model[0].module.module.state_dict()
-         sd2 = model_test[0].module.module.state_dict()
+         sd = model[0].module.state_dict()
+         sd2 = model_test[0].module.state_dict()
        for k in sd.keys():
            if sd[k] is None:
                continue
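Note (not part of the patch): the comparison loop above is truncated in this view. A hedged sketch of the kind of key-by-key check the test path performs after reloading the saved checkpoint (the exact assertions in the script may differ):

import torch

def assert_state_dicts_match(sd, sd2):
    # Walk the reference state dict and compare entries one by one,
    # skipping None values just as the loop above does.
    for k, v in sd.items():
        if v is None:
            continue
        assert k in sd2, f"key {k} missing after reload"
        if torch.is_tensor(v) and torch.is_tensor(sd2[k]):
            assert torch.equal(v, sd2[k]), f"tensor mismatch at {k}"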
@@ -163,11 +213,11 @@ def megatron_value_model_provider(pre_process, post_process):
            model_type=ModelType.encoder_or_decoder,
            wrap_with_ddp=True,
        )
-         ssd2 = model_value[0].module.module.sharded_state_dict()
+         ssd2 = model_value[0].module.sharded_state_dict()
        dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL)

-         sd = model[0].module.module.state_dict()
-         sd2 = model_value[0].module.module.state_dict()
+         sd = model[0].module.state_dict()
+         sd2 = model_value[0].module.state_dict()
        for k in sd.keys():
            if sd[k] is None:
                continue
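Note (not part of the patch): end to end, converting and round-trip testing a Qwen2-MoE checkpoint comes down to calling the function defined in this file; the model id and output path below are placeholders.

# Placeholders; the converter initializes model parallelism itself
# (parallel sizes of 1 in the hunks above), so a single-process run suffices.
convert_hf_to_mcore(
    hf_model_path="Qwen/Qwen1.5-MoE-A2.7B",
    output_path="/tmp/qwen2_moe_mcore_ckpt",
    test=True,
)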