Commit d4cae44
[mcore] option to use dist checkpoint (#1030)
mcore dist checkpointing is a parallel-invariant weight format: you can save and load under arbitrary parallel settings, e.g. save with tp2pp2 and load with tp4pp1. This PR introduces an option to use dist checkpoints with the mcore backend. It is *disabled* by default for backward compatibility, but future support for *mcore MoE models and VLM models* will only work with dist ckpt enabled, for an easier implementation.

Before this PR, when initializing the actor and critic workers, each GPU loaded the entire huggingface weights and then re-sharded them into the correct mcore model state dict, making the procedure slow and complicated. With this PR, we convert the hf weights to a dist ckpt with an offline script, and each GPU loads only its own shards from the dist ckpt (a usage sketch follows below). Loading is faster and no online resharding is needed: when loading `Qwen2-7B-Instruct` for the critic worker, the loading time drops from 109s to 25s, a *4.36x* speedup.

The `converter_hf_to_mcore.py` in this version uses the existing online resharding function to convert weights; it should be refactored for better efficiency and for MoE/VLM models. Thanks to #998 for the optimization of loading hf weights only on GPU 0.

Future TODO:
* refactor the converter for efficiency
* support converting MoE models
* support converting VLM models
* re-design `megatron_checkpoint_manager.py` with dist ckpt
* implement a converter from mcore dist ckpt to hf / `model_merger.py`
* add docs and example scripts
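For orientation, a minimal usage sketch of the intended workflow (not part of this commit; the model name, paths, and the import path are illustrative placeholders):

```python
# One-off offline conversion; the script is normally invoked from the command line:
#   python scripts/converter_hf_to_mcore.py --hf_model_path <hf_model> --output_path <ckpt_dir> [--test]
# Programmatic equivalent, assuming the scripts/ directory is on PYTHONPATH (placeholder paths):
from converter_hf_to_mcore import convert_hf_to_mcore

convert_hf_to_mcore(
    hf_model_path="Qwen/Qwen2-7B-Instruct",         # HF checkpoint to convert
    output_path="/data/ckpts/qwen2_7b_mcore_dist",  # target dist-checkpoint directory
    test=True,                                      # optionally reload and compare state dicts
)

# Training side: enable dist-checkpoint loading (disabled by default) in
# ppo_megatron_trainer.yaml, e.g. for the critic:
#   critic:
#     megatron:
#       use_dist_checkpointing: True
#       dist_checkpointing_path: /data/ckpts/qwen2_7b_mcore_dist
```

With `--test`, the converter reloads the saved checkpoint into freshly built causal-LM and value models and asserts that the shared tensors match; the same two config keys also exist under `actor_rollout_ref.actor.megatron`, `actor_rollout_ref.ref.megatron`, and `reward_model.megatron`.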
1 parent 6dd5e39 commit d4cae44

File tree

5 files changed (+252 / -17 lines)


scripts/converter_hf_to_mcore.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple, Dict
+import re
+import os
+import torch
+import argparse
+import warnings
+import numpy as np
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
+from concurrent.futures import ThreadPoolExecutor
+from safetensors.torch import load_file
+from torch.distributed._tensor import Shard, Placement
+from verl.utils.megatron_utils import get_model, convert_config
+from megatron.core.models.gpt.gpt_model import ModelType
+from megatron.core import parallel_state as mpu
+from megatron.core import dist_checkpointing
+from megatron.core.dist_checkpointing.serialization import StrictHandling
+
+
+def _init_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
+    parser.add_argument('--output_path', type=str, required=True, help="The path for the output mcore model")
+    parser.add_argument('--test', action='store_true', help="Whether to test the conversion")
+    args = parser.parse_args()
+    return args
+
+
+class MegatronConfig:
+
+    def __init__(self):
+        self.params_dtype = torch.bfloat16
+
+
+class ModelConfig:
+
+    def __init__(self):
+        self.path = None
+
+
+class Config:
+
+    def __init__(self):
+        self.model = ModelConfig()
+
+
+def convert_hf_to_mcore(hf_model_path, output_path, test=False):
+    os.makedirs(output_path, exist_ok=True)
+    if len(os.listdir(output_path)) > 0 and not test:
+        print(f"Output path {output_path} is not empty, skipping conversion")
+        return
+
+    # init torch distributed and mpu
+    os.environ['RANK'] = '0'
+    os.environ['WORLD_SIZE'] = '1'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+    torch.distributed.init_process_group('nccl')
+    mpu.initialize_model_parallel(tensor_model_parallel_size=1,
+                                  virtual_pipeline_model_parallel_size=None,
+                                  context_parallel_size=1,
+                                  expert_model_parallel_size=1)
+
+    # init hf config
+    hf_config = AutoConfig.from_pretrained(hf_model_path)
+    print(hf_config)
+    megatron_config = MegatronConfig()
+    cfg = Config()
+    cfg.model.path = hf_model_path
+    tfconfig = convert_config(hf_config, megatron_config)
+    tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
+
+    # init megatron model
+    def megatron_model_provider(pre_process, post_process):
+        from verl.utils.model import get_parallel_gptmodel_from_config
+        parallel_model = get_parallel_gptmodel_from_config(tfconfig,
+                                                           hf_config,
+                                                           pre_process,
+                                                           post_process,
+                                                           share_embeddings_and_output_weights=tie_word_embeddings,
+                                                           value=False)
+        return parallel_model
+
+    model = get_model(model_provider_func=megatron_model_provider,
+                      model_type=ModelType.encoder_or_decoder,
+                      wrap_with_ddp=True)
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+
+        # init hf model
+        hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)
+        ref_state_dict = hf_model.state_dict()
+
+        # load hf state dict to megatron model
+        from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
+        load_state_dict_to_megatron_gptmodel(state_dict=ref_state_dict,
+                                             wrapped_models=model,
+                                             config=hf_config,
+                                             params_dtype=torch.bfloat16,
+                                             is_value_model=False)
+    ssd = model[0].module.module.sharded_state_dict()
+    del ref_state_dict, hf_model
+
+    # save megatron model
+    if len(os.listdir(output_path)) == 0:
+        dist_checkpointing.save(ssd, output_path, sharded_strategy=None, async_sharded_save=False)
+    if test:
+        ########### test ###########
+        # load model
+        model_test = get_model(model_provider_func=megatron_model_provider,
+                               model_type=ModelType.encoder_or_decoder,
+                               wrap_with_ddp=True)
+        ssd2 = model_test[0].module.module.sharded_state_dict()
+        dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)
+
+        sd = model[0].module.module.state_dict()
+        sd2 = model_test[0].module.module.state_dict()
+        for k in sd.keys():
+            if sd[k] is None:
+                continue
+            d1 = sd[k].data
+            if k in sd2:
+                d2 = sd2[k].data
+                assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
+                assert (d1 == d2).all(), f"{k} is not equal"
+        for k in sd2.keys():
+            if sd2[k] is None:
+                continue
+            d1 = sd2[k].data
+            if k in sd:
+                d2 = sd[k].data
+                assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
+                assert (d1 == d2).all(), f"{k} is not equal"
+
+        # load value model
+        def megatron_value_model_provider(pre_process, post_process):
+            from verl.utils.model import get_parallel_gptmodel_from_config
+            parallel_model = get_parallel_gptmodel_from_config(tfconfig,
+                                                               hf_config,
+                                                               pre_process,
+                                                               post_process,
+                                                               share_embeddings_and_output_weights=False,
+                                                               value=True)
+            parallel_model.cuda()
+            return parallel_model
+
+        model_value = get_model(model_provider_func=megatron_value_model_provider,
+                                model_type=ModelType.encoder_or_decoder,
+                                wrap_with_ddp=True)
+        ssd2 = model_value[0].module.module.sharded_state_dict()
+        dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL)
+
+        sd = model[0].module.module.state_dict()
+        sd2 = model_value[0].module.module.state_dict()
+        for k in sd.keys():
+            if sd[k] is None:
+                continue
+            d1 = sd[k].data
+            if k in sd2:
+                d2 = sd2[k].data
+                assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
+                assert (d1 == d2).all(), f"{k} is not equal"
+        for k in sd2.keys():
+            if sd2[k] is None:
+                continue
+            d1 = sd2[k].data
+            if k in sd:
+                d2 = sd[k].data
+                assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
+                assert (d1 == d2).all(), f"{k} is not equal"
+
+
+if __name__ == "__main__":
+    args = _init_args()
+    convert_hf_to_mcore(args.hf_model_path, args.output_path, args.test)

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 8 additions & 0 deletions
@@ -70,6 +70,8 @@ actor_rollout_ref:
 context_parallel_size: 1
 sequence_parallel: True
 use_distributed_optimizer: True
+use_dist_checkpointing: False
+dist_checkpointing_path: null
 seed: 1
 load_weight: True
 checkpoint:

@@ -82,6 +84,8 @@ actor_rollout_ref:
 context_parallel_size: 1
 sequence_parallel: True
 use_distributed_optimizer: True
+use_dist_checkpointing: False
+dist_checkpointing_path: null
 seed: 1
 load_weight: True
 param_offload: False

@@ -153,6 +157,8 @@ critic:
 context_parallel_size: 1
 sequence_parallel: True
 use_distributed_optimizer: True
+use_dist_checkpointing: False
+dist_checkpointing_path: null
 seed: 1
 load_weight: True
 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}

@@ -178,6 +184,8 @@ reward_model:
 context_parallel_size: 1
 sequence_parallel: True
 use_distributed_optimizer: True
+use_dist_checkpointing: False
+dist_checkpointing_path: null
 seed: 1
 model:
 input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical

verl/utils/megatron_utils.py

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC
        batch_p2p_comm=batch_p2p_comm,
        pipeline_dtype=dt,
        params_dtype=dt,
-       sequence_parallel=True,
+       sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,
        variable_seq_lengths=True,
        masked_softmax_fusion=True,
        moe_token_dispatcher_type="alltoall",

verl/utils/model.py

Lines changed: 17 additions & 0 deletions
@@ -409,6 +409,23 @@ def load_megatron_gptmodel_weights(config,
    del state_dict, model


+def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):
+    from megatron.core import dist_checkpointing
+    from megatron.core.dist_checkpointing.serialization import StrictHandling
+
+    # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED
+    strict = StrictHandling.ASSUME_OK_UNEXPECTED
+    for model in parallel_model:
+        ssd = model.module.module.sharded_state_dict()
+        if is_value_model:
+            for k in list(ssd.keys()):
+                if "output_layer" in k:
+                    ssd.pop(k)
+        dist_checkpointing.load(ssd, dist_weight_path, strict=strict)
+
+    return
+
+
 def get_parallel_gptmodel_from_config(tfconfig,
                                       hf_config,
                                       pre_process=None,
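
For context, a sketch of how this helper is meant to be consumed (the real wiring is in `megatron_workers.py` below; the module list and path here are placeholders):

```python
# Sketch only: `critic_module` stands for the list of DDP-wrapped Megatron model chunks
# returned by get_model(...) inside the worker; the path is the converter's output directory.
from verl.utils.model import load_mcore_dist_weights

load_mcore_dist_weights(critic_module,
                        "/data/ckpts/qwen2_7b_mcore_dist",
                        is_value_model=True)  # drops "output_layer" entries before loading
```

Because the dist checkpoint is parallel-invariant, each rank loads only the shards requested by its own `sharded_state_dict()`, regardless of the tensor/pipeline parallel layout used when the checkpoint was written.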

verl/workers/megatron_workers.py

Lines changed: 36 additions & 16 deletions
@@ -17,6 +17,7 @@

 import os
 import logging
+import time
 import ray
 import torch
 import torch.distributed

@@ -33,7 +34,7 @@
 from verl import DataProto
 from verl.utils.fs import copy_to_local
 from verl.utils.debug import log_gpu_memory_usage
-from verl.utils.model import load_megatron_model_weights, load_megatron_gptmodel_weights
+from verl.utils.model import load_megatron_model_weights, load_megatron_gptmodel_weights, load_mcore_dist_weights
 from verl.utils.flops_counter import FlopsCounter
 from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager
 from verl.utils.megatron_utils import mcore_model_parallel_config

@@ -204,11 +205,16 @@ def megatron_actor_model_provider(pre_process, post_process):
        actor_module = actor_modules_list
        print(f'actor_module: {len(actor_module)}')
        if self.config.actor.load_weight:
-           load_megatron_gptmodel_weights(self.config,
-                                          actor_model_config,
-                                          actor_module,
-                                          params_dtype=megatron_config.params_dtype,
-                                          is_value_model=False)
+           if self.config.actor.megatron.use_dist_checkpointing:
+               load_mcore_dist_weights(actor_module,
+                                       self.config.actor.megatron.dist_checkpointing_path,
+                                       is_value_model=False)
+           else:
+               load_megatron_gptmodel_weights(self.config,
+                                              actor_model_config,
+                                              actor_module,
+                                              params_dtype=megatron_config.params_dtype,
+                                              is_value_model=False)

        if self.rank == 0:
            print_model_size(actor_module[0])

@@ -224,11 +230,16 @@ def megatron_actor_model_provider(pre_process, post_process):
        if self.config.ref.load_weight: # should align with the actor:
            assert self.config.actor.load_weight == self.config.ref.load_weight
            print(f'load ref weight start')
-           load_megatron_gptmodel_weights(self.config,
-                                          actor_model_config,
-                                          ref_module,
-                                          params_dtype=megatron_config.params_dtype,
-                                          is_value_model=False)
+           if self.config.ref.megatron.use_dist_checkpointing:
+               load_mcore_dist_weights(ref_module,
+                                       self.config.ref.megatron.dist_checkpointing_path,
+                                       is_value_model=False)
+           else:
+               load_megatron_gptmodel_weights(self.config,
+                                              actor_model_config,
+                                              ref_module,
+                                              params_dtype=megatron_config.params_dtype,
+                                              is_value_model=False)
        log_gpu_memory_usage('After ref module init', logger=logger)
        return ref_module, actor_model_config

@@ -571,11 +582,20 @@ def megatron_critic_model_provider(pre_process, post_process):
        # critic_module = nn.ModuleList(critic_module)

        if self.config.load_weight:
-           load_megatron_gptmodel_weights(self.config,
-                                          critic_model_config,
-                                          critic_module,
-                                          params_dtype=megatron_config.params_dtype,
-                                          is_value_model=True)
+           t0 = time.time()
+           if self.config.megatron.use_dist_checkpointing:
+               load_mcore_dist_weights(critic_module,
+                                       self.config.megatron.dist_checkpointing_path,
+                                       is_value_model=True)
+           else:
+               load_megatron_gptmodel_weights(self.config,
+                                              critic_model_config,
+                                              critic_module,
+                                              params_dtype=megatron_config.params_dtype,
+                                              is_value_model=True)
+           t1 = time.time()
+           if torch.distributed.get_rank() == 0:
+               print(f'critic load_weight time: {t1 - t0}')
        if self.rank == 0:
            print_model_size(critic_module[0])