forked from PaddlePaddle/PaddleNLP

Commit
[LLM] Add tools for parameters (PaddlePaddle#9137)
* add tools
1 parent dbd3947 · commit 37c211a
Showing 3 changed files with 316 additions and 0 deletions.
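The commit adds three tools: a Python script that converts a GQA checkpoint into an MHA checkpoint by replicating the key/value projections, a shell launcher, and split_weights.py, which shards model weights and quantization scales across tensor-parallel ranks.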
@@ -0,0 +1,136 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle

from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM
from paddlenlp.transformers.model_utils import load_tp_checkpoint


def parse_arguments():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gqa_model_path", type=str, required=True, help="the dir of the GQA model weight")
    parser.add_argument("--mha_model_path", type=str, required=True, help="the dir to save the MHA model weight")
    parser.add_argument(
        "--model_prefix_name", default="model_state", type=str, required=False, help="model prefix name"
    )
    return parser.parse_args()


def convert(gqa_model_path, mha_model_path, config_path):
    """
    Convert a model from GQA to MHA.

    Args:
        gqa_model_path (str): the path of the GQA model weight
        mha_model_path (str): the path to save the MHA model weight
        config_path (str): the path of the model's config
    """
    # config_path is only validated by the caller; the config itself is loaded from gqa_model_path.
    config = AutoConfig.from_pretrained(gqa_model_path)

    model = AutoModelForCausalLM.from_pretrained(gqa_model_path)
    model_state = load_tp_checkpoint(gqa_model_path, model, config)

    model_type = config["model_type"]
    hidden_size = config["hidden_size"]
    num_head = config["num_attention_heads"]
    num_key_value_heads = config["num_key_value_heads"]
    dim_head = hidden_size // num_head
    num_layers = config["num_hidden_layers"]
    # Each KV head is shared by this many query heads and must be replicated that many times.
    num_gqa_partitions = num_head // num_key_value_heads

    for i in range(num_layers):
        print(f"layer: {i}")
        # qkv weight [hidden_size, (num_head + 2 * num_key_value_heads) * dim_head]
        q_weight = model_state[f"{model_type}.layers.{i}.self_attn.q_proj.weight"]
        k_weight = model_state[f"{model_type}.layers.{i}.self_attn.k_proj.weight"]
        v_weight = model_state[f"{model_type}.layers.{i}.self_attn.v_proj.weight"]
        print(f"q_weight.shape: {q_weight.shape}")
        print(f"k_weight.shape: {k_weight.shape}")
        print(f"v_weight.shape: {v_weight.shape}")

        k_weight = k_weight.reshape([hidden_size, num_key_value_heads, dim_head])
        v_weight = v_weight.reshape([hidden_size, num_key_value_heads, dim_head])
        print(f"(reshape) k_weight.shape: {k_weight.shape}")
        print(f"(reshape) v_weight.shape: {v_weight.shape}")

        # Repeat each KV head num_gqa_partitions times so every query head gets its own copy.
        kk_weight = paddle.reshape(
            paddle.stack([k_weight] * num_gqa_partitions, axis=2), [hidden_size, num_head, dim_head]
        )
        vv_weight = paddle.reshape(
            paddle.stack([v_weight] * num_gqa_partitions, axis=2), [hidden_size, num_head, dim_head]
        )
        print(f"(extend) k_weight.shape: {kk_weight.shape}")
        print(f"(extend) v_weight.shape: {vv_weight.shape}")

        new_k_weight = kk_weight.reshape([hidden_size, num_head * dim_head])
        new_v_weight = vv_weight.reshape([hidden_size, num_head * dim_head])
        print(f"new_k_weight.shape: {new_k_weight.shape}")
        print(f"new_v_weight.shape: {new_v_weight.shape}")

        model_state[f"{model_type}.layers.{i}.self_attn.k_proj.weight"] = new_k_weight
        model_state[f"{model_type}.layers.{i}.self_attn.v_proj.weight"] = new_v_weight

        # Biases, when present, are replicated the same way as the weights.
        if (
            f"{model_type}.layers.{i}.self_attn.q_proj.bias" in model_state
            and f"{model_type}.layers.{i}.self_attn.k_proj.bias" in model_state
            and f"{model_type}.layers.{i}.self_attn.v_proj.bias" in model_state
        ):
            print("bias")

            q_bias = model_state[f"{model_type}.layers.{i}.self_attn.q_proj.bias"]
            k_bias = model_state[f"{model_type}.layers.{i}.self_attn.k_proj.bias"]
            v_bias = model_state[f"{model_type}.layers.{i}.self_attn.v_proj.bias"]
            print(f"q_bias.shape: {q_bias.shape}")
            print(f"k_bias.shape: {k_bias.shape}")
            print(f"v_bias.shape: {v_bias.shape}")

            k_bias = k_bias.reshape([num_key_value_heads, dim_head])
            v_bias = v_bias.reshape([num_key_value_heads, dim_head])
            print(f"(reshape) k_bias.shape: {k_bias.shape}")
            print(f"(reshape) v_bias.shape: {v_bias.shape}")

            kk_bias = paddle.reshape(paddle.stack([k_bias] * num_gqa_partitions, axis=1), [num_head, dim_head])
            vv_bias = paddle.reshape(paddle.stack([v_bias] * num_gqa_partitions, axis=1), [num_head, dim_head])
            print(f"(extend) k_bias.shape: {kk_bias.shape}")
            print(f"(extend) v_bias.shape: {vv_bias.shape}")

            new_k_bias = kk_bias.reshape([num_head * dim_head])
            new_v_bias = vv_bias.reshape([num_head * dim_head])
            print(f"new_k_bias.shape: {new_k_bias.shape}")
            print(f"new_v_bias.shape: {new_v_bias.shape}")

            model_state[f"{model_type}.layers.{i}.self_attn.k_proj.bias"] = new_k_bias
            model_state[f"{model_type}.layers.{i}.self_attn.v_proj.bias"] = new_v_bias

    paddle.save(model_state, mha_model_path)


if __name__ == "__main__":
    """
    Script to convert a model from GQA to MHA.
    """
    args = parse_arguments()
    config_path = os.path.join(args.gqa_model_path, "config.json")
    mha_model_path = os.path.join(args.mha_model_path, f"{args.model_prefix_name}.pdparams")

    assert os.path.exists(args.gqa_model_path), f"{args.gqa_model_path} is not found"
    assert os.path.exists(config_path), f"config.json is not found in {args.gqa_model_path}"
    convert(args.gqa_model_path, mha_model_path, config_path)
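As a sanity check (a toy sketch, not part of the patch), the stack-and-reshape above repeats each KV head contiguously num_gqa_partitions times, which is exactly the GQA-to-MHA expansion:

import paddle

# Hypothetical toy sizes: 2 KV heads, each shared by 3 query heads.
hidden_size, num_key_value_heads, num_gqa_partitions, dim_head = 4, 2, 3, 2
num_head = num_key_value_heads * num_gqa_partitions

k = paddle.rand([hidden_size, num_key_value_heads, dim_head])
kk = paddle.reshape(paddle.stack([k] * num_gqa_partitions, axis=2), [hidden_size, num_head, dim_head])

# Expanded head h is a copy of original KV head h // num_gqa_partitions.
for h in range(num_head):
    assert bool(paddle.all(kk[:, h, :] == k[:, h // num_gqa_partitions, :]))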
@@ -0,0 +1,24 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export device="0,1,3,4"
export CUDA_VISIBLE_DEVICES=${device}

model_path=${1:-"/path/to/model"}

python -m paddle.distributed.launch \
    --gpus ${device} \
    split_weights.py \
    --model_path ${model_path} \
    --output_path ${model_path}/tp4
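Usage note: the script takes the model directory as an optional first positional argument (falling back to /path/to/model), pins the four listed GPUs via CUDA_VISIBLE_DEVICES, and writes the split weights under ${model_path}/tp4.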
@@ -0,0 +1,156 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import numpy as np
import paddle

from paddlenlp.generation import GenerationConfig
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.transformers.model_utils import load_tp_checkpoint
from paddlenlp.utils import llm_utils


def parse_arguments():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default=None, type=str, required=True, help="The directory of the model.")
    parser.add_argument("--output_path", default=None, type=str, help="The directory of the split model.")
    parser.add_argument("--model_rank_id", default=None, type=int, help="1-based rank id of the input model shard.")
    parser.add_argument("--dtype", default="float16", type=str, help="The dtype of model weights.")
    return parser.parse_args()


def split(args):
    """
    Split the model weights across tensor-parallel ranks.
    """
    rank, nranks = llm_utils.init_dist_env()

    if args.output_path is None:
        args.output_path = os.path.join(args.model_path, f"{nranks}_ranks")

    paddle.set_default_dtype(args.dtype)

    config = AutoConfig.from_pretrained(args.model_path)
    config.tensor_parallel_degree = nranks
    config.tensor_parallel_rank = rank

    generation_config = GenerationConfig.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    if args.model_rank_id is not None:
        model_path = os.path.join(args.model_path, f"model_state.tp0{args.model_rank_id - 1}.pdparams")
        assert os.path.isfile(model_path), f"{model_path} does not exist"
        state_dict = load_tp_checkpoint(args.model_path, model, config)
        model_rank = args.model_rank_id
        save_base_rank = model_rank * nranks
    else:
        state_dict = load_tp_checkpoint(args.model_path, model, config)
        model_rank = 0
        save_base_rank = 0

    weight_file = os.path.join(args.output_path, f"model_state.tp0{rank + save_base_rank}.pdparams")
    paddle.save(state_dict, weight_file)

    # process weight scales
    possible_weight_scales_path = os.path.join(args.model_path, f"weight_scales_{model_rank}.json")
    if os.path.exists(possible_weight_scales_path) and rank == 0:
        with open(possible_weight_scales_path, "r") as f:
            weight_scales_dict = json.load(f)

        column_parallel_keys = (
            "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "mlp.gate_proj", "mlp.up_proj"
        )
        row_parallel_keys = ("self_attn.o_proj", "mlp.down_proj")
        processed_weight_scales = [{} for _ in range(nranks)]
        for k, v in weight_scales_dict.items():
            if any(name in k for name in column_parallel_keys):
                # Column-parallel layers: split the per-channel scales across ranks.
                splited_value = np.split(np.array(v), nranks, axis=-1)
                for tp_rank in range(nranks):
                    processed_weight_scales[tp_rank][k] = splited_value[tp_rank].tolist()
            elif any(name in k for name in row_parallel_keys):
                # Row-parallel layers: every rank keeps the full scales.
                for tp_rank in range(nranks):
                    processed_weight_scales[tp_rank][k] = v
            else:
                raise ValueError(f"key {k} is not supported!")

        for tp_rank in range(nranks):
            save_path = os.path.join(args.output_path, f"weight_scales_{tp_rank + save_base_rank}.json")
            print("weight scale save_path:", save_path)
            with open(save_path, "w") as f:
                json.dump(processed_weight_scales[tp_rank], f)

    # process cachekv scales
    possible_cache_path = os.path.join(args.model_path, f"cachekv_scales_{model_rank}.json")
    if os.path.exists(possible_cache_path) and rank == 0:
        with open(possible_cache_path, "r") as f:
            cache_dict = json.load(f)

        processed_cachekv_scales = [{} for _ in range(nranks)]
        for k, v in cache_dict.items():
            # Flatten the per-head scales, then split them evenly across ranks.
            v = np.array(v).flatten()
            splited_value = np.split(v, nranks, axis=-1)
            for tp_rank in range(nranks):
                processed_cachekv_scales[tp_rank][k] = splited_value[tp_rank].tolist()
        for tp_rank in range(nranks):
            save_path = os.path.join(args.output_path, f"cachekv_scales_{tp_rank + save_base_rank}.json")
            print("cachekv scale save_path:", save_path)
            with open(save_path, "w") as f:
                json.dump(processed_cachekv_scales[tp_rank], f)

    # process act scales
    possible_act_scales_path = os.path.join(args.model_path, f"act_scales_{model_rank}.json")
    if os.path.exists(possible_act_scales_path) and rank == 0:
        with open(possible_act_scales_path, "r") as f:
            act_scale = json.load(f)
        # Activation scales are not sharded; every rank gets a full copy.
        for tp_rank in range(nranks):
            save_path = os.path.join(args.output_path, f"act_scales_{tp_rank + save_base_rank}.json")
            print("act scale save_path:", save_path)
            with open(save_path, "w") as outf:
                json.dump(act_scale, outf)

    if rank == 0:
        tokenizer.save_pretrained(args.output_path)
        config.save_pretrained(args.output_path)
        generation_config.save_pretrained(args.output_path)


if __name__ == "__main__":
    """
    Script to split model weights.
    """
    args = parse_arguments()
    split(args)
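For intuition, a small sketch (toy values, not from the patch) of the splitting rule used above: column-parallel projections such as q_proj get their per-channel weight scales divided evenly across ranks with np.split, while row-parallel projections (o_proj, down_proj) keep the full list on every rank.

import numpy as np

nranks = 2
q_scales = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # hypothetical per-channel scales
shards = np.split(q_scales, nranks, axis=-1)
print(shards[0].tolist(), shards[1].tolist())  # [1.0, 2.0] [3.0, 4.0]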