|
| 1 | +# Copyright 2025 Bytedance Ltd. and/or its affiliates |
| 2 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +from typing import List, Tuple, Dict |
| 17 | +import re |
| 18 | +import os |
| 19 | +import torch |
| 20 | +import argparse |
| 21 | +import warnings |
| 22 | +import numpy as np |
| 23 | +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq |
| 24 | +from concurrent.futures import ThreadPoolExecutor |
| 25 | +from safetensors.torch import load_file |
| 26 | +from torch.distributed._tensor import Shard, Placement |
| 27 | +from verl.utils.megatron_utils import get_model, convert_config |
| 28 | +from megatron.core.models.gpt.gpt_model import ModelType |
| 29 | +from megatron.core import parallel_state as mpu |
| 30 | +from megatron.core import dist_checkpointing |
| 31 | +from megatron.core.dist_checkpointing.serialization import StrictHandling |
| 32 | + |
| 33 | + |
| 34 | +def _init_args(): |
| 35 | + parser = argparse.ArgumentParser() |
| 36 | + parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model") |
| 37 | + parser.add_argument('--output_path', type=str, required=True, help="The path for the output mcore model") |
| 38 | + parser.add_argument('--test', action='store_true', help="Whether to test the conversion") |
| 39 | + args = parser.parse_args() |
| 40 | + return args |
| 41 | + |
| 42 | + |
| 43 | +class MegatronConfig: |
| 44 | + |
| 45 | + def __init__(self): |
| 46 | + self.params_dtype = torch.bfloat16 |
| 47 | + |
| 48 | + |
| 49 | +class ModelConfig: |
| 50 | + |
| 51 | + def __init__(self): |
| 52 | + self.path = None |
| 53 | + |
| 54 | + |
| 55 | +class Config: |
| 56 | + |
| 57 | + def __init__(self): |
| 58 | + self.model = ModelConfig() |
| 59 | + |
| 60 | + |
| 61 | +def convert_hf_to_mcore(hf_model_path, output_path, test=False): |
| 62 | + os.makedirs(output_path, exist_ok=True) |
| 63 | + if len(os.listdir(output_path)) > 0 and not test: |
| 64 | + print(f"Output path {output_path} is not empty, skipping conversion") |
| 65 | + return |
| 66 | + |
| 67 | + # init torch distributed and mpu |
| 68 | + os.environ['RANK'] = '0' |
| 69 | + os.environ['WORLD_SIZE'] = '1' |
| 70 | + os.environ['MASTER_ADDR'] = 'localhost' |
| 71 | + os.environ['MASTER_PORT'] = '12355' |
| 72 | + torch.distributed.init_process_group('nccl') |
| 73 | + mpu.initialize_model_parallel(tensor_model_parallel_size=1, |
| 74 | + virtual_pipeline_model_parallel_size=None, |
| 75 | + context_parallel_size=1, |
| 76 | + expert_model_parallel_size=1) |
| 77 | + |
| 78 | + # init hf config |
| 79 | + hf_config = AutoConfig.from_pretrained(hf_model_path) |
| 80 | + print(hf_config) |
| 81 | + megatron_config = MegatronConfig() |
| 82 | + cfg = Config() |
| 83 | + cfg.model.path = hf_model_path |
| 84 | + tfconfig = convert_config(hf_config, megatron_config) |
| 85 | + tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False) |
| 86 | + |
| 87 | + # init megatron model |
| 88 | + def megatron_model_provider(pre_process, post_process): |
| 89 | + from verl.utils.model import get_parallel_gptmodel_from_config |
| 90 | + parallel_model = get_parallel_gptmodel_from_config(tfconfig, |
| 91 | + hf_config, |
| 92 | + pre_process, |
| 93 | + post_process, |
| 94 | + share_embeddings_and_output_weights=tie_word_embeddings, |
| 95 | + value=False) |
| 96 | + return parallel_model |
| 97 | + |
| 98 | + model = get_model(model_provider_func=megatron_model_provider, |
| 99 | + model_type=ModelType.encoder_or_decoder, |
| 100 | + wrap_with_ddp=True) |
| 101 | + |
| 102 | + with warnings.catch_warnings(): |
| 103 | + warnings.simplefilter("ignore") |
| 104 | + |
| 105 | + # init hf model |
| 106 | + hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path) |
| 107 | + ref_state_dict = hf_model.state_dict() |
| 108 | + |
| 109 | + # load hf state dict to megatron model |
| 110 | + from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel |
| 111 | + load_state_dict_to_megatron_gptmodel(state_dict=ref_state_dict, |
| 112 | + wrapped_models=model, |
| 113 | + config=hf_config, |
| 114 | + params_dtype=torch.bfloat16, |
| 115 | + is_value_model=False) |
| 116 | + ssd = model[0].module.module.sharded_state_dict() |
| 117 | + del ref_state_dict, hf_model |
| 118 | + |
| 119 | + # save megatron model |
| 120 | + if len(os.listdir(output_path)) == 0: |
| 121 | + dist_checkpointing.save(ssd, output_path, sharded_strategy=None, async_sharded_save=False) |
| 122 | + if test: |
| 123 | + ########### test ########### |
| 124 | + # load model |
| 125 | + model_test = get_model(model_provider_func=megatron_model_provider, |
| 126 | + model_type=ModelType.encoder_or_decoder, |
| 127 | + wrap_with_ddp=True) |
| 128 | + ssd2 = model_test[0].module.module.sharded_state_dict() |
| 129 | + dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED) |
| 130 | + |
| 131 | + sd = model[0].module.module.state_dict() |
| 132 | + sd2 = model_test[0].module.module.state_dict() |
| 133 | + for k in sd.keys(): |
| 134 | + if sd[k] is None: |
| 135 | + continue |
| 136 | + d1 = sd[k].data |
| 137 | + if k in sd2: |
| 138 | + d2 = sd2[k].data |
| 139 | + assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' |
| 140 | + assert (d1 == d2).all(), f"{k} is not equal" |
| 141 | + for k in sd2.keys(): |
| 142 | + if sd2[k] is None: |
| 143 | + continue |
| 144 | + d1 = sd2[k].data |
| 145 | + if k in sd: |
| 146 | + d2 = sd[k].data |
| 147 | + assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' |
| 148 | + assert (d1 == d2).all(), f"{k} is not equal" |
| 149 | + |
| 150 | + # load value model |
| 151 | + def megatron_value_model_provider(pre_process, post_process): |
| 152 | + from verl.utils.model import get_parallel_gptmodel_from_config |
| 153 | + parallel_model = get_parallel_gptmodel_from_config(tfconfig, |
| 154 | + hf_config, |
| 155 | + pre_process, |
| 156 | + post_process, |
| 157 | + share_embeddings_and_output_weights=False, |
| 158 | + value=True) |
| 159 | + parallel_model.cuda() |
| 160 | + return parallel_model |
| 161 | + |
| 162 | + model_value = get_model(model_provider_func=megatron_value_model_provider, |
| 163 | + model_type=ModelType.encoder_or_decoder, |
| 164 | + wrap_with_ddp=True) |
| 165 | + ssd2 = model_value[0].module.module.sharded_state_dict() |
| 166 | + dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL) |
| 167 | + |
| 168 | + sd = model[0].module.module.state_dict() |
| 169 | + sd2 = model_value[0].module.module.state_dict() |
| 170 | + for k in sd.keys(): |
| 171 | + if sd[k] is None: |
| 172 | + continue |
| 173 | + d1 = sd[k].data |
| 174 | + if k in sd2: |
| 175 | + d2 = sd2[k].data |
| 176 | + assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' |
| 177 | + assert (d1 == d2).all(), f"{k} is not equal" |
| 178 | + for k in sd2.keys(): |
| 179 | + if sd2[k] is None: |
| 180 | + continue |
| 181 | + d1 = sd2[k].data |
| 182 | + if k in sd: |
| 183 | + d2 = sd[k].data |
| 184 | + assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}' |
| 185 | + assert (d1 == d2).all(), f"{k} is not equal" |
| 186 | + |
| 187 | + |
| 188 | +if __name__ == "__main__": |
| 189 | + args = _init_args() |
| 190 | + convert_hf_to_mcore(args.hf_model_path, args.output_path, args.test) |
0 commit comments