Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions scripts/converter_hf_to_mcore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Dict
import re
import os
import torch
import argparse
import warnings
import numpy as np
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from concurrent.futures import ThreadPoolExecutor
from safetensors.torch import load_file
from torch.distributed._tensor import Shard, Placement
from verl.utils.megatron_utils import get_model, convert_config
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core import parallel_state as mpu
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import StrictHandling


def _init_args():
parser = argparse.ArgumentParser()
parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
parser.add_argument('--output_path', type=str, required=True, help="The path for the output mcore model")
parser.add_argument('--test', action='store_true', help="Whether to test the conversion")
args = parser.parse_args()
return args


class MegatronConfig:

def __init__(self):
self.params_dtype = torch.bfloat16


class ModelConfig:

def __init__(self):
self.path = None


class Config:

def __init__(self):
self.model = ModelConfig()


def convert_hf_to_mcore(hf_model_path, output_path, test=False):
os.makedirs(output_path, exist_ok=True)
if len(os.listdir(output_path)) > 0 and not test:
print(f"Output path {output_path} is not empty, skipping conversion")
return

# init torch distributed and mpu
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
torch.distributed.init_process_group('nccl')
mpu.initialize_model_parallel(tensor_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
expert_model_parallel_size=1)

# init hf config
hf_config = AutoConfig.from_pretrained(hf_model_path)
print(hf_config)
megatron_config = MegatronConfig()
cfg = Config()
cfg.model.path = hf_model_path
tfconfig = convert_config(hf_config, megatron_config)
tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)

# init megatron model
def megatron_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_gptmodel_from_config
parallel_model = get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=tie_word_embeddings,
value=False)
return parallel_model

model = get_model(model_provider_func=megatron_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)

with warnings.catch_warnings():
warnings.simplefilter("ignore")

# init hf model
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)
ref_state_dict = hf_model.state_dict()

# load hf state dict to megatron model
from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
load_state_dict_to_megatron_gptmodel(state_dict=ref_state_dict,
wrapped_models=model,
config=hf_config,
params_dtype=torch.bfloat16,
is_value_model=False)
ssd = model[0].module.module.sharded_state_dict()
del ref_state_dict, hf_model

# save megatron model
if len(os.listdir(output_path)) == 0:
dist_checkpointing.save(ssd, output_path, sharded_strategy=None, async_sharded_save=False)
if test:
########### test ###########
# load model
model_test = get_model(model_provider_func=megatron_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
ssd2 = model_test[0].module.module.sharded_state_dict()
dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)

sd = model[0].module.module.state_dict()
sd2 = model_test[0].module.module.state_dict()
for k in sd.keys():
if sd[k] is None:
continue
d1 = sd[k].data
if k in sd2:
d2 = sd2[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"
for k in sd2.keys():
if sd2[k] is None:
continue
d1 = sd2[k].data
if k in sd:
d2 = sd[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"

# load value model
def megatron_value_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_gptmodel_from_config
parallel_model = get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=False,
value=True)
parallel_model.cuda()
return parallel_model

model_value = get_model(model_provider_func=megatron_value_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
ssd2 = model_value[0].module.module.sharded_state_dict()
dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL)

sd = model[0].module.module.state_dict()
sd2 = model_value[0].module.module.state_dict()
for k in sd.keys():
if sd[k] is None:
continue
d1 = sd[k].data
if k in sd2:
d2 = sd2[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"
for k in sd2.keys():
if sd2[k] is None:
continue
d1 = sd2[k].data
if k in sd:
d2 = sd[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"


if __name__ == "__main__":
args = _init_args()
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.test)
1 change: 0 additions & 1 deletion scripts/model_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
except ImportError:
from torch.distributed._tensor import DTensor


parser = argparse.ArgumentParser()
parser.add_argument('--backend', type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"])
parser.add_argument('--tie-word-embedding', action='store_true', help="Whether to tie word embedding weights")
Expand Down
8 changes: 8 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ actor_rollout_ref:
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
load_weight: True
checkpoint:
Expand All @@ -79,6 +81,8 @@ actor_rollout_ref:
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
load_weight: True
param_offload: False
Expand Down Expand Up @@ -150,6 +154,8 @@ critic:
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
load_weight: True
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
Expand All @@ -175,6 +181,8 @@ reward_model:
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
Expand Down
2 changes: 1 addition & 1 deletion verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC
batch_p2p_comm=batch_p2p_comm,
pipeline_dtype=dt,
params_dtype=dt,
sequence_parallel=True,
sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,
variable_seq_lengths=True,
masked_softmax_fusion=True,
moe_token_dispatcher_type="alltoall",
Expand Down
17 changes: 17 additions & 0 deletions verl/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,23 @@ def load_megatron_gptmodel_weights(config,
del state_dict, model


def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import StrictHandling

# strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED
strict = StrictHandling.ASSUME_OK_UNEXPECTED
for model in parallel_model:
ssd = model.module.module.sharded_state_dict()
if is_value_model:
for k in list(ssd.keys()):
if "output_layer" in k:
ssd.pop(k)
dist_checkpointing.load(ssd, dist_weight_path, strict=strict)

return


def get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process=None,
Expand Down
52 changes: 36 additions & 16 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import os
import logging
import time
import ray
import torch
import torch.distributed
Expand All @@ -33,7 +34,7 @@
from verl import DataProto
from verl.utils.fs import copy_to_local
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.model import load_megatron_model_weights, load_megatron_gptmodel_weights
from verl.utils.model import load_megatron_model_weights, load_megatron_gptmodel_weights, load_mcore_dist_weights
from verl.utils.flops_counter import FlopsCounter
from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager
from verl.utils.megatron_utils import mcore_model_parallel_config
Expand Down Expand Up @@ -204,11 +205,16 @@ def megatron_actor_model_provider(pre_process, post_process):
actor_module = actor_modules_list
print(f'actor_module: {len(actor_module)}')
if self.config.actor.load_weight:
load_megatron_gptmodel_weights(self.config,
actor_model_config,
actor_module,
params_dtype=megatron_config.params_dtype,
is_value_model=False)
if self.config.actor.megatron.use_dist_checkpointing:
load_mcore_dist_weights(actor_module,
self.config.actor.megatron.dist_checkpointing_path,
is_value_model=False)
else:
load_megatron_gptmodel_weights(self.config,
actor_model_config,
actor_module,
params_dtype=megatron_config.params_dtype,
is_value_model=False)

if self.rank == 0:
print_model_size(actor_module[0])
Expand All @@ -224,11 +230,16 @@ def megatron_actor_model_provider(pre_process, post_process):
if self.config.ref.load_weight: # should align with the actor:
assert self.config.actor.load_weight == self.config.ref.load_weight
print(f'load ref weight start')
load_megatron_gptmodel_weights(self.config,
actor_model_config,
ref_module,
params_dtype=megatron_config.params_dtype,
is_value_model=False)
if self.config.ref.megatron.use_dist_checkpointing:
load_mcore_dist_weights(ref_module,
self.config.ref.megatron.dist_checkpointing_path,
is_value_model=False)
else:
load_megatron_gptmodel_weights(self.config,
actor_model_config,
ref_module,
params_dtype=megatron_config.params_dtype,
is_value_model=False)
log_gpu_memory_usage('After ref module init', logger=logger)
return ref_module, actor_model_config

Expand Down Expand Up @@ -569,11 +580,20 @@ def megatron_critic_model_provider(pre_process, post_process):
# critic_module = nn.ModuleList(critic_module)

if self.config.load_weight:
load_megatron_gptmodel_weights(self.config,
critic_model_config,
critic_module,
params_dtype=megatron_config.params_dtype,
is_value_model=True)
t0 = time.time()
if self.config.megatron.use_dist_checkpointing:
load_mcore_dist_weights(critic_module,
self.config.megatron.dist_checkpointing_path,
is_value_model=True)
else:
load_megatron_gptmodel_weights(self.config,
critic_model_config,
critic_module,
params_dtype=megatron_config.params_dtype,
is_value_model=True)
t1 = time.time()
if torch.distributed.get_rank() == 0:
print(f'critic load_weight time: {t1 - t0}')
if self.rank == 0:
print_model_size(critic_module[0])

Expand Down
Loading