From 1e5ed312e2ea8ccc18eb4ceb4126b393aa790ddc Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Mon, 4 Dec 2023 11:51:41 +0100
Subject: [PATCH] Distributed inference of 70B awq model (#2531)

* Distributed inference of 70B awq model

* fix overflow
---
 onmt/bin/translate.py         |  29 +--
 onmt/inference_engine.py      |  16 +-
 onmt/inputters/text_utils.py  |  22 +-
 onmt/model_builder.py         |   6 +-
 onmt/models/model.py          |  27 ++-
 onmt/opts.py                  |  12 +-
 onmt/translate/translator.py  |   5 +-
 onmt/utils/distributed.py     |   3 +-
 tools/LM_scoring.py           |  15 +-
 tools/convert_HF_llamalike.py | 376 +++++++++++++++++++++++-----------
 10 files changed, 340 insertions(+), 171 deletions(-)

diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py
index 0cd76edcc0..bf3d25f1fc 100644
--- a/onmt/bin/translate.py
+++ b/onmt/bin/translate.py
@@ -1,10 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from onmt.utils.logging import init_logger
-from onmt.translate.translator import build_translator
-from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
-from onmt.transforms import get_transforms_cls
-from onmt.constants import CorpusTask
+from onmt.inference_engine import InferenceEnginePY
 from onmt.opts import config_opts, translate_opts
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed
@@ -17,29 +13,12 @@ def translate(opt):
     ArgumentParser._get_all_transform_translate(opt)
     ArgumentParser._validate_transforms_opts(opt)
     ArgumentParser.validate_translate_opts_dynamic(opt)
-    logger = init_logger(opt.log_file)
 
     set_random_seed(opt.seed, use_gpu(opt))
 
-    translator = build_translator(opt, logger=logger, report_score=False)
-
-    transforms_cls = get_transforms_cls(opt._all_transform)
-
-    infer_iter = build_dynamic_dataset_iter(
-        opt,
-        transforms_cls,
-        translator.vocabs,
-        task=CorpusTask.INFER,
-        copy=translator.copy_attn,
-        device_id=opt.gpu,
-    )
-
-    _, _ = translator._translate(
-        infer_iter,
-        transform=infer_iter.transforms,
-        attn_debug=opt.attn_debug,
-        align_debug=opt.align_debug,
-    )
+    engine = InferenceEnginePY(opt)
+    _, _ = engine.infer_file()
+    engine.terminate()


 def _get_parser():
diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py
index 5c7f04a985..279f836ee6 100755
--- a/onmt/inference_engine.py
+++ b/onmt/inference_engine.py
@@ -2,7 +2,7 @@
 from onmt.constants import CorpusTask, DefaultTokens, ModelTask
 from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.utils.distributed import ErrorHandler, spawned_infer
-from onmt.utils.logging import logger
+from onmt.utils.logging import init_logger
 from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe


@@ -82,6 +82,7 @@ def __init__(self, opt):
         super().__init__(opt)

         self.opt = opt
+        self.logger = init_logger(opt.log_file)

         if opt.world_size > 1:
             mp = torch.multiprocessing.get_context("spawn")
@@ -92,10 +93,6 @@ def __init__(self, opt):
             self.queue_result = []
             self.procs = []

-            print("world_size: ", opt.world_size)
-            print("gpu_ranks: ", opt.gpu_ranks)
-            print("opt.gpu: ", opt.gpu)
-
             for device_id in range(opt.world_size):
                 self.queue_instruct.append(mp.Queue())
                 self.queue_result.append(mp.Queue())
@@ -113,12 +110,11 @@ def __init__(self, opt):
                     )
                 )
                 self.procs[device_id].start()
-                print(" Starting process pid: %d " % self.procs[device_id].pid)
                 self.error_handler.add_child(self.procs[device_id].pid)
         else:
-            self.device_id = 0 if opt.world_size == 1 else -1
+            self.device_id = opt.gpu
             self.translator = build_translator(
-                opt, self.device_id, logger=logger, report_score=True
+                opt, self.device_id, logger=self.logger, report_score=True
             )
             self.transforms_cls = get_transforms_cls(opt._all_transform)
             self.vocabs = self.translator.vocabs
@@ -168,9 +164,9 @@ def __init__(self, opt):
         super().__init__(opt)

         self.opt = opt
-        self.logger = logger
+        self.logger = init_logger(opt.log_file)
         assert self.opt.world_size <= 1, "World size must be less than 1."
-        self.device_id = 0 if opt.world_size == 1 else -1
+        self.device_id = opt.gpu
         if opt.world_size == 1:
             self.device_index = opt.gpu_ranks
             self.device = "cuda"
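The rewritten translate.py above delegates the whole pipeline to the inference engine updated in onmt/inference_engine.py. A minimal sketch of the resulting call flow, assuming an opt namespace produced by the translate option parser; the wrapper function below is illustrative, not part of the patch:

from onmt.inference_engine import InferenceEnginePY


def run_file_inference(opt):
    # InferenceEnginePY builds a single translator, or spawns one worker per
    # GPU when opt.world_size > 1, as in the constructor shown above.
    engine = InferenceEnginePY(opt)
    scores, preds = engine.infer_file()  # runs inference on the file(s) configured in opt
    engine.terminate()                   # stops spawned workers and frees the model
    return scores, preds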
diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py
index c6e9e8ff1b..42ebb3c1c7 100644
--- a/onmt/inputters/text_utils.py
+++ b/onmt/inputters/text_utils.py
@@ -246,10 +246,19 @@ def tensorify(vocabs, minibatch, device, left_pad=False):
     )

     if minibatch[0][0]["tgt"] is not None:
-        tbatchtgt = [
-            torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
-            for ex, indice in minibatch
-        ]
+        if left_pad:
+            tbatchtgt = [
+                torch.tensor(
+                    ex["tgt"]["tgt_ids"], dtype=torch.long, device=device
+                ).flip(dims=[0])
+                for ex, indice in minibatch
+            ]
+        else:
+            tbatchtgt = [
+                torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
+                for ex, indice in minibatch
+            ]
+
         padidx = vocabs["tgt"][DefaultTokens.PAD]
         tbatchtgt = pad_sequence(tbatchtgt, batch_first=True, padding_value=padidx)
         tbatchtgt = tbatchtgt[:, :, None]
@@ -258,7 +267,10 @@
             dtype=torch.long,
             device=device,
         )
-        tensor_batch["tgt"] = tbatchtgt
+        if left_pad:
+            tensor_batch["tgt"] = tbatchtgt.flip(dims=[1])
+        else:
+            tensor_batch["tgt"] = tbatchtgt
         tensor_batch["tgtlen"] = tbatchtgtlen

     if "align" in minibatch[0][0].keys() and minibatch[0][0]["align"] is not None:
diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index 0019019c24..37831c50d1 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -100,7 +100,11 @@ def load_test_model(opt, device_id=0, model_path=None):
         "aawq_gemm",
         "aawq_gemv",
     ]:  # if the loaded model is a awq quantized one, inference config cannot overwrite this
-        if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
+        if (
+            hasattr(opt, "quant_type")
+            and opt.quant_type != ""
+            and opt.quant_type != model_opt.quant_type
+        ):
             raise ValueError(
                 "Model is a awq quantized model, cannot overwrite with another quant method"
             )
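The left_pad branch added to tensorify() above builds left-padded target batches by flipping each sequence before pad_sequence and flipping the padded batch back. A self-contained sketch of that trick; the helper name is illustrative:

import torch
from torch.nn.utils.rnn import pad_sequence


def left_pad_batch(sequences, pad_idx):
    # flip each sequence, right-pad as usual, then flip the padded batch back
    # so the padding ends up on the left side of every row
    flipped = [torch.tensor(seq, dtype=torch.long).flip(dims=[0]) for seq in sequences]
    padded = pad_sequence(flipped, batch_first=True, padding_value=pad_idx)
    return padded.flip(dims=[1])


# left_pad_batch([[5, 6, 7], [8, 9]], 0) -> tensor([[5, 6, 7], [0, 8, 9]])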
diff --git a/onmt/models/model.py b/onmt/models/model.py
index 2025968317..40f0ce534d 100644
--- a/onmt/models/model.py
+++ b/onmt/models/model.py
@@ -46,6 +46,10 @@ def count_parameters(self, log=print):
         raise NotImplementedError

     def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset):
+        if module.__class__.__name__ == "WQLinear_GEMM":
+            # ugly patch because in_feat and out_feat are reversed in WQLinear_GEMM
+            param.data = param.data.transpose(0, 1)
+            ckpt_t = ckpt_t.transpose(0, 1)
         if name.split(".")[-1] in [
             "linear_keys",
             "linear_values",
@@ -73,13 +77,22 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset):
                 ].size()
             ), "An error in model's partition and checkpoint's slice was detected"
             if name + "." + param_name in buf_list:
-                module.register_buffer(
-                    param_name,
-                    ckpt_t[
-                        col_slice_start:col_slice_end,
-                        row_slice_start:row_slice_end,
-                    ],
-                )
+                if module.__class__.__name__ == "WQLinear_GEMM":
+                    module.register_buffer(
+                        param_name,
+                        ckpt_t[
+                            col_slice_start:col_slice_end,
+                            row_slice_start:row_slice_end,
+                        ].transpose(0, 1),
+                    )
+                else:
+                    module.register_buffer(
+                        param_name,
+                        ckpt_t[
+                            col_slice_start:col_slice_end,
+                            row_slice_start:row_slice_end,
+                        ],
+                    )
             else:
                 param.data = ckpt_t[
                     col_slice_start:col_slice_end,
diff --git a/onmt/opts.py b/onmt/opts.py
index 421210c02c..861e96172e 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -1564,8 +1564,16 @@ def _add_quant_opts(parser):
     group.add(
         "--quant_type",
         "-quant_type",
-        default="bnb_8bit",
-        choices=["bnb_8bit", "bnb_FP4", "bnb_NF4", "llm_awq", "aawq_gemm", "aawq_gemv"],
+        default="",
+        choices=[
+            "",
+            "bnb_8bit",
+            "bnb_FP4",
+            "bnb_NF4",
+            "llm_awq",
+            "aawq_gemm",
+            "aawq_gemv",
+        ],
         type=str,
         help="Type of compression.",
     )
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 856f786b3c..a27fa72ae0 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -619,7 +619,10 @@ def _report_score(self, name, score_total, nb_sentences):
             msg = "%s No translations" % (name,)
         else:
             score = score_total / nb_sentences
-            ppl = exp(-score_total / nb_sentences)
+            try:
+                ppl = exp(-score_total / nb_sentences)
+            except OverflowError:
+                ppl = float("inf")
             msg = "%s SCORE: %.4f, %s PPL: %.2f NB SENTENCES: %d" % (
                 name,
                 score,
diff --git a/onmt/utils/distributed.py b/onmt/utils/distributed.py
index e2acf3749e..c7e6051a7c 100644
--- a/onmt/utils/distributed.py
+++ b/onmt/utils/distributed.py
@@ -197,7 +197,6 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
         init_logger(opt.log_file)
         translator = build_translator(opt, device_id, logger=logger, report_score=True)
         transforms_cls = get_transforms_cls(opt._all_transform)
-        print("Device_id: ", device_id, " translator built")
         while True:
             instruction = queue_instruct.get()
             if instruction[0] == "stop":
@@ -227,7 +226,7 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
                     device_id=device_id,
                 )
                 scores, preds = translator._translate(
-                    infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug
+                    infer_iter, infer_iter.transforms, opt.attn_debug, opt.align_debug
                 )
                 queue_result.put(scores)
                 queue_result.put(preds)
diff --git a/tools/LM_scoring.py b/tools/LM_scoring.py
index 823571fa0e..e07632f3d9 100644
--- a/tools/LM_scoring.py
+++ b/tools/LM_scoring.py
@@ -86,12 +86,12 @@ def main():
     cumul_length = 0

     # Now we can pipe the full file through the model using the Iterator
-    for i, batch in enumerate(infer_iter):
+    for i, (batch, bucket_idx) in enumerate(infer_iter):
         # reminder a batch includes .src .tgt .indices and it is sorted
         batch_size = len(batch["srclen"])
         src = batch["src"]
         src_len = batch["srclen"]
-
+        # print(batch)
         outputs, attns = model(src, None, src_len, with_align=False)
         # Compute and retrieve the loss for EACH sentence
         loss, _ = valid_loss(batch, outputs, attns)
@@ -102,7 +102,16 @@ def main():
         cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu()
         # Now we need to rearrange the batch of ppl
         # in the original order with indices
-        sent_ppl_orig = ppl.gather(0, batch["cid_line_number"].argsort(0))
+        sent_ppl_orig = ppl.gather(
+            0,
+            torch.tensor(
+                sorted(
+                    range(len(batch["cid_line_number"])),
+                    key=lambda k: batch["cid_line_number"][k],
+                ),
+                device=ppl.device,
+            ),
+        )
         for j in range(batch_size):
             ppl_file.write(str(sent_ppl_orig[j].item()) + "\n")
         logger.info(
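The argsort-free reordering above puts each perplexity back at its original line position using the bucket's cid_line_number values, which are a plain Python list rather than a tensor. A self-contained sketch of the same gather pattern; the function name is illustrative:

import torch


def reorder_to_input_order(ppl, line_numbers):
    # indices that would sort line_numbers ascending, computed in plain Python
    # because line_numbers is a list here, not a tensor
    order = sorted(range(len(line_numbers)), key=lambda k: line_numbers[k])
    return ppl.gather(0, torch.tensor(order, device=ppl.device))


# reorder_to_input_order(torch.tensor([0.5, 0.2, 0.9]), [2, 0, 1])
#   -> tensor([0.2000, 0.9000, 0.5000])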
diff --git a/tools/convert_HF_llamalike.py b/tools/convert_HF_llamalike.py
index c375404b4e..64053d72a1 100755
--- a/tools/convert_HF_llamalike.py
+++ b/tools/convert_HF_llamalike.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 import torch
+import json
 import argparse
 import pyonmttok
+import safetensors
 from argparse import Namespace
 from onmt.inputters.inputter import vocabs_to_dict
 from onmt.constants import DefaultTokens
 from sentencepiece import SentencePieceProcessor
 import os
-from transformers import AutoModelForCausalLM, AutoConfig
 import huggingface_hub
 from safetensors.torch import save_file

@@ -43,21 +44,42 @@ def __init__(self, model_path: str):
     parser.add_argument(
         "--nshards", type=int, default=1, help="""Path to the model directory"""
     )
-    opt = parser.parse_args()
-
-    model = AutoModelForCausalLM.from_pretrained(
-        opt.model_dir,
-        torch_dtype=torch.float16,
-        # device_map={"": "cpu"},
-        trust_remote_code=True,
+    parser.add_argument(
+        "--token",
+        type=str,
+        default="",
+        help="""HF token""",
     )
-    checkpoint = model.state_dict()
-    if opt.format == "pytorch" and opt.nshards > 1:
-        raise ValueError("Saving several shards in pytorch format is not supported")
+    opt = parser.parse_args()

-    if os.path.exists(os.path.join(opt.model_dir, "tokenizer.model")):
-        tokenizer_model = os.path.join(opt.model_dir, "tokenizer.model")
+    if os.path.exists(opt.model_dir):
+        if os.path.exists(os.path.join(opt.model_dir, "config.json")):
+            config_path = os.path.join(opt.model_dir, "config.json")
+        else:
+            raise ValueError("You used a local directory but config.json is missing")
+        if os.path.exists(os.path.join(opt.model_dir, "model.safetensors.index.json")):
+            wmap_path = os.path.join(opt.model_dir, "model.safetensors.index.json")
+        elif os.path.exists(
+            os.path.join(opt.model_dir, "pytorch_model.bin.index.json")
+        ):
+            wmap_path = os.path.join(opt.model_dir, "pytorch_model.bin.index.json")
+        elif os.path.exists(os.path.join(opt.model_dir, "model.safetensors")):
+            wmap_path = None
+            model_path = os.path.join(opt.model_dir, "model.safetensors")
+        elif os.path.exists(os.path.join(opt.model_dir, "pytorch_model.bin")):
+            wmap_path = None
+            model_path = os.path.join(opt.model_dir, "pytorch_model.bin")
+        else:
+            raise ValueError(
+                "Could not find any proper model configuration, please check your files"
+            )
+        if os.path.exists(os.path.join(opt.model_dir, "tokenizer.model")):
+            tokenizer_model = os.path.join(opt.model_dir, "tokenizer.model")
+        else:
+            raise ValueError(
+                "You used a local directory but tokenizer.model is missing"
+            )
     else:
         directory_path, _ = os.path.split(opt.output)
         os.makedirs(directory_path, exist_ok=True)
         try:
             tokenizer_model = huggingface_hub.hf_hub_download(
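The local-directory branch above probes for the checkpoint layout: an index file for sharded weights, or a single model.safetensors / pytorch_model.bin. A compact sketch of the same lookup; the helper name and the (index_path, weight_path) return convention are illustrative:

import os


def resolve_checkpoint(model_dir):
    # returns (wmap_path, model_path): an index json for sharded checkpoints,
    # or a single weight file when the model is not sharded
    candidates = [
        ("model.safetensors.index.json", True),
        ("pytorch_model.bin.index.json", True),
        ("model.safetensors", False),
        ("pytorch_model.bin", False),
    ]
    for name, is_index in candidates:
        path = os.path.join(model_dir, name)
        if os.path.exists(path):
            return (path, None) if is_index else (None, path)
    raise ValueError("Could not find any proper model configuration")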
@@ -66,48 +88,101 @@ def __init__(self, model_path: str):
                 repo_id=opt.model_dir,
                 filename="tokenizer.model",
                 local_dir=directory_path,
+                token=opt.token,
             )
         except huggingface_hub.utils.EntryNotFoundError:
-            print(
+            raise huggingface_hub.utils.EntryNotFoundError(
                 "Make sure the repo contains tokenizer.model - needed for all Llama-like models"
             )
-            exit()
-
-    config = AutoConfig.from_pretrained(opt.model_dir)
-    decoder_layers = config.num_hidden_layers
-    src_word_vec_size = config.hidden_size
-    tgt_word_vec_size = config.hidden_size
-    hidden_size = config.hidden_size
-    heads = config.num_attention_heads
-    vocab_size = config.vocab_size
-    transformer_ff = config.intermediate_size
-
-    if hasattr(config, "num_key_value_heads") and config.num_key_value_heads != heads:
-        num_kv = config.num_key_value_heads
+        try:
+            config_path = huggingface_hub.hf_hub_download(
+                repo_id=opt.model_dir,
+                filename="config.json",
+                local_dir=directory_path,
+                token=opt.token,
+            )
+        except huggingface_hub.utils.EntryNotFoundError:
+            raise huggingface_hub.utils.EntryNotFoundError(
+                "Something went wrong the repo does not contain any config.json file"
+            )
+        try:
+            wmap_path = huggingface_hub.hf_hub_download(
+                repo_id=opt.model_dir,
+                filename="model.safetensors.index.json",
+                local_dir=directory_path,
+                token=opt.token,
+            )
+        except huggingface_hub.utils.EntryNotFoundError:
+            try:
+                wmap_path = huggingface_hub.hf_hub_download(
+                    repo_id=opt.model_dir,
+                    filename="pytorch_model.bin.index.json",
+                    local_dir=directory_path,
+                    token=opt.token,
+                )
+            except huggingface_hub.utils.EntryNotFoundError:
+                try:
+                    model_path = huggingface_hub.hf_hub_download(
+                        repo_id=opt.model_dir,
+                        filename="model.safetensors",
+                        local_dir=directory_path,
+                        token=opt.token,
+                    )
+                    wmap_path = None
+                except huggingface_hub.utils.EntryNotFoundError:
+                    try:
+                        model_path = huggingface_hub.hf_hub_download(
+                            repo_id=opt.model_dir,
+                            filename="pytorch_model.bin",
+                            local_dir=directory_path,
+                            token=opt.token,
+                        )
+                        wmap_path = None
+                    except huggingface_hub.utils.EntryNotFoundError:
+                        raise huggingface_hub.utils.EntryNotFoundError(
+                            "No valid model files found"
+                        )
+
+    with open(config_path, encoding="utf-8") as fconfig:
+        config = json.load(fconfig)
+
+    decoder_layers = config["num_hidden_layers"]
+    src_word_vec_size = config["hidden_size"]
+    tgt_word_vec_size = config["hidden_size"]
+    hidden_size = config["hidden_size"]
+    heads = config["num_attention_heads"]
+    vocab_size = config["vocab_size"]
+    transformer_ff = config["intermediate_size"]
+
+    if (
+        "num_key_value_heads" in config.keys()
+        and config["num_key_value_heads"] != heads
+    ):
+        num_kv = config["num_key_value_heads"]
     else:
         num_kv = 0
-    if hasattr(config, "rms_norm_eps"):
-        norm_eps = config.rms_norm_eps
+    if "rms_norm_eps" in config.keys():
+        norm_eps = config["rms_norm_eps"]
     else:
         norm_eps = 1e-6
-    if hasattr(config, "sliding_window"):
-        sliding_window = config.sliding_window
+    if "sliding_window" in config.keys():
+        sliding_window = config["sliding_window"]
     else:
         sliding_window = 0
-    if hasattr(config, "quantization_config"):
+    if "quantization_config" in config.keys():
         if (
-            "quant_method" in config.quantization_config.keys()
-            and config.quantization_config["quant_method"] == "awq"
+            "quant_method" in config["quantization_config"].keys()
+            and config["quantization_config"]["quant_method"] == "awq"
         ):
-            if "backend" in config.quantization_config.keys():
-                backend = config.quantization_config["backend"]
+            if "backend" in config["quantization_config"].keys():
+                backend = config["quantization_config"]["backend"]
                 if backend == "llm-awq":
                     quant_type = "llm_awq"
                 elif backend == "autoawq":
-                    if config.quantization_config["version"].lower() == "gemm":
+                    if config["quantization_config"]["version"].lower() == "gemm":
                         quant_type = "aawq_gemm"
-                    elif config.quantization_config["version"].lower() == "gemv":
+                    elif config["quantization_config"]["version"].lower() == "gemv":
                         quant_type = "aawq_gemv"
                     else:
                         raise ValueError("Unknown quantization config")
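With transformers dropped from the imports, the converter above reads config.json directly. A minimal sketch of the same idea; the helper name is illustrative, and dict.get mirrors the fallback defaults used above:

import json


def read_llama_config(config_path):
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    return {
        "decoder_layers": config["num_hidden_layers"],
        "hidden_size": config["hidden_size"],
        "heads": config["num_attention_heads"],
        "vocab_size": config["vocab_size"],
        "transformer_ff": config["intermediate_size"],
        # optional fields, with the same fallbacks as above; the script also
        # zeroes num_kv when it equals the number of attention heads
        "num_kv": config.get("num_key_value_heads", 0),
        "norm_eps": config.get("rms_norm_eps", 1e-6),
        "sliding_window": config.get("sliding_window", 0),
    }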
@@ -115,22 +190,22 @@ def __init__(self, model_path: str):
                 else:
                     raise ValueError("Unknown backend config")
             else:
                 print("Backend not specified in config, using Autoawq")
-                if config.quantization_config["version"].lower() == "gemm":
+                if config["quantization_config"]["version"].lower() == "gemm":
                     quant_type = "aawq_gemm"
-                elif config.quantization_config["version"].lower() == "gemv":
+                elif config["quantization_config"]["version"].lower() == "gemv":
                     quant_type = "aawq_gemv"
                 else:
                     raise ValueError("Unknown quantization config")
         else:
             raise ValueError("Can convert only awq models for now")
-        if "bits" in config.quantization_config.keys():
-            w_bit = config.quantization_config["bits"]
+        if "bits" in config["quantization_config"].keys():
+            w_bit = config["quantization_config"]["bits"]
         else:
-            w_bit = config.quantization_config["w_bit"]
-        if "group_size" in config.quantization_config.keys():
-            group_size = config.quantization_config["group_size"]
+            w_bit = config["quantization_config"]["w_bit"]
+        if "group_size" in config["quantization_config"].keys():
+            group_size = config["quantization_config"]["group_size"]
         else:
-            group_size = config.quantization_config["q_group_size"]
+            group_size = config["quantization_config"]["q_group_size"]

     quant_layers = [
         "w_1",
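The branches above reduce to mapping an HF quantization_config block onto the converter's quant_type, w_bit and group_size. A simplified sketch (AWQ only, unknown-backend error handling omitted; the function name is illustrative):

def awq_options(qconf):
    # only AWQ checkpoints are convertible; the kernel flavour comes from
    # either the "backend" field or the "version" field
    if qconf.get("quant_method") != "awq":
        raise ValueError("Can convert only awq models for now")
    if qconf.get("backend") == "llm-awq":
        quant_type = "llm_awq"
    elif qconf.get("version", "gemm").lower() == "gemv":
        quant_type = "aawq_gemv"
    else:
        quant_type = "aawq_gemm"
    w_bit = qconf["bits"] if "bits" in qconf else qconf["w_bit"]
    group_size = qconf["group_size"] if "group_size" in qconf else qconf["q_group_size"]
    return quant_type, w_bit, group_size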
@@ -151,87 +226,158 @@ def __init__(self, model_path: str):

     onmt_cp = {}

+    if wmap_path:
+        with open(wmap_path, encoding="utf-8") as fweights:
+            wmap = json.load(fweights)
+
+    def get_load_ckpt(dir_path, file_path):
+        if os.path.exists(os.path.join(dir_path, file_path)):
+            ckpt_path = os.path.join(dir_path, file_path)
+        else:
+            try:
+                ckpt_path = huggingface_hub.hf_hub_download(
+                    repo_id=opt.model_dir,
+                    filename=file_path,
+                    local_dir=dir_path,
+                    token=opt.token,
+                )
+            except huggingface_hub.utils.EntryNotFoundError:
+                raise huggingface_hub.utils.EntryNotFoundError(
+                    "Checkpoint not found on the hub"
+                )
+            except PermissionError:
+                ckpt_path = os.path.join(dir_path, file_path)
+        if ckpt_path[-3:] == ".pt":
+            checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
+        else:
+            checkpoint = ckpt_path
+
+        return checkpoint
+
+    def get_weight(checkpoint, tensor_name):
+        if isinstance(checkpoint, dict):
+            if tensor_name in checkpoint.keys():
+                return checkpoint[tensor_name]
+            else:
+                return None
+        else:
+            with safetensors.safe_open(checkpoint, framework="pt", device="cpu") as f:
+                if tensor_name in f.keys():
+                    return f.get_tensor(tensor_name)
+                else:
+                    return None
+
     for shard in range(opt.nshards):

         print("starting output shard: %d/%d" % (shard + 1, opt.nshards))
         onmt_safetensor = {}

         if shard == 0:
-            onmt_safetensor[
-                "decoder.embeddings.make_embedding.emb_luts.0.weight"
-            ] = checkpoint["model.embed_tokens.weight"]
-            onmt_safetensor["decoder.layer_norm.weight"] = checkpoint[
-                "model.norm.weight"
+            sourcelist = [
+                "model.embed_tokens.weight",
+                "model.norm.weight",
+                "lm_head.weight",
             ]
+            targetlist = [
+                "decoder.embeddings.make_embedding.emb_luts.0.weight",
+                "decoder.layer_norm.weight",
+                "generator.weight",
+            ]
+
+            for source, target in zip(sourcelist, targetlist):
+                if wmap_path:
+                    checkpoint = get_load_ckpt(
+                        os.path.split(wmap_path)[0], wmap["weight_map"][source]
+                    )
+                else:
+                    checkpoint = get_load_ckpt(*os.path.split(model_path))
+                w = get_weight(checkpoint, source)
+                if w is not None:
+                    onmt_safetensor[target] = w

-            onmt_safetensor["generator.weight"] = checkpoint["lm_head.weight"]
             onmt_safetensor["generator.bias"] = torch.zeros(
                 onmt_safetensor["generator.weight"].size(0), dtype=torch.float16
             )

-        for i in range(
-            -(decoder_layers // -opt.nshards) * shard,
-            min(-(decoder_layers // -opt.nshards) * (shard + 1), decoder_layers),
-            1,
-        ):
-            onmt_safetensor[
-                "decoder.transformer_layers." + str(i) + ".layer_norm_1.weight"
-            ] = checkpoint["model.layers." + str(i) + ".input_layernorm.weight"]
-
-            for param in params:
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".self_attn.linear_query."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".self_attn.q_proj." + param]
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".self_attn.linear_keys."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".self_attn.k_proj." + param]
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".self_attn.linear_values."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".self_attn.v_proj." + param]
-
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".self_attn.final_linear."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".self_attn.o_proj." + param]
-
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".feed_forward.w_1."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".mlp.gate_proj." + param]
-
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".feed_forward.w_2."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".mlp.down_proj." + param]
-                onmt_safetensor[
-                    "decoder.transformer_layers."
-                    + str(i)
-                    + ".feed_forward.w_3."
-                    + param
-                ] = checkpoint["model.layers." + str(i) + ".mlp.up_proj." + param]
-
-            onmt_safetensor[
-                "decoder.transformer_layers."
-                + str(i)
-                + ".feed_forward.layer_norm.weight"
-            ] = checkpoint[
-                "model.layers." + str(i) + ".post_attention_layernorm.weight"
-            ]
+        if wmap_path:
+            weightmap = wmap["weight_map"]
+            ckpt_list = []
+            for key in weightmap.keys():
+                if (
+                    key.startswith("model.layers.")
+                    and int(key.split(".")[2])
+                    in range(
+                        -(decoder_layers // -opt.nshards) * shard,
+                        min(
+                            -(decoder_layers // -opt.nshards) * (shard + 1),
+                            decoder_layers,
+                        ),
+                        1,
+                    )
+                    and weightmap[key] not in ckpt_list
+                ):
+                    ckpt_list.append(weightmap[key])
+        else:
+            ckpt_list = [model_path]
+
+        for ckpt in ckpt_list:
+            print("Loading %s" % ckpt)
+            if wmap_path:
+                checkpoint = get_load_ckpt(os.path.split(wmap_path)[0], ckpt)
+            else:
+                checkpoint = get_load_ckpt(*os.path.split(model_path))
+            for i in range(
+                -(decoder_layers // -opt.nshards) * shard,
+                min(-(decoder_layers // -opt.nshards) * (shard + 1), decoder_layers),
+                1,
+            ):
+
+                w = get_weight(
+                    checkpoint, "model.layers." + str(i) + ".input_layernorm.weight"
+                )
+                if w is not None:
+                    onmt_safetensor[
+                        "decoder.transformer_layers." + str(i) + ".layer_norm_1.weight"
+                    ] = w
+
+                for param in params:
+                    sourcelist = [
+                        ".self_attn.q_proj.",
+                        ".self_attn.k_proj.",
+                        ".self_attn.v_proj.",
+                        ".self_attn.o_proj.",
+                        ".mlp.gate_proj.",
+                        ".mlp.down_proj.",
+                        ".mlp.up_proj.",
+                    ]
+                    targetlist = [
+                        ".self_attn.linear_query.",
+                        ".self_attn.linear_keys.",
+                        ".self_attn.linear_values.",
+                        ".self_attn.final_linear.",
+                        ".feed_forward.w_1.",
+                        ".feed_forward.w_2.",
+                        ".feed_forward.w_3.",
+                    ]
+                    for source, target in zip(sourcelist, targetlist):
+                        w = get_weight(
+                            checkpoint, "model.layers." + str(i) + source + param
+                        )
+                        if w is not None:
+                            onmt_safetensor[
+                                "decoder.transformer_layers." + str(i) + target + param
+                            ] = w
+
+                w = get_weight(
+                    checkpoint,
+                    "model.layers." + str(i) + ".post_attention_layernorm.weight",
+                )
+                if w is not None:
+                    onmt_safetensor[
+                        "decoder.transformer_layers."
+                        + str(i)
+                        + ".feed_forward.layer_norm.weight"
+                    ] = w

         if shard == 0:
             vocab_size = onmt_safetensor["generator.weight"].size(0)
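Both the weight-map filter and the layer loop above rely on -(a // -b), i.e. ceiling division, to assign decoder layers to output shards. A small illustration; the function name is illustrative:

def shard_layer_range(decoder_layers, nshards, shard):
    # -(a // -b) is ceil(a / b) in integer arithmetic, so every shard except
    # possibly the last one receives the same number of contiguous layers
    per_shard = -(decoder_layers // -nshards)
    return range(per_shard * shard, min(per_shard * (shard + 1), decoder_layers))


# e.g. 80 layers over 3 shards:
# [shard_layer_range(80, 3, s) for s in range(3)]
#   -> [range(0, 27), range(27, 54), range(54, 80)]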