[LLM] Add tools for parameters (PaddlePaddle#9137)
* add tools
Hanyonggong authored Oct 10, 2024
1 parent dbd3947 commit 37c211a
Showing 3 changed files with 316 additions and 0 deletions.
136 changes: 136 additions & 0 deletions llm/tools/convert_gqa_to_mha.py
@@ -0,0 +1,136 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle

from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM
from paddlenlp.transformers.model_utils import load_tp_checkpoint


def parse_arguments():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gqa_model_path", type=str, required=True, help="the directory of the GQA model weights")
    parser.add_argument("--mha_model_path", type=str, required=True, help="the directory to save the MHA model weights")
    parser.add_argument(
        "--model_prefix_name", default="model_state", type=str, required=False, help="prefix of the saved weight file"
    )
    return parser.parse_args()


def convert(gqa_model_path, mha_model_path, config_path):
    """
    Convert model weights from GQA to MHA.
    Args:
        gqa_model_path (str): the path of the GQA model weights
        mha_model_path (str): the path to save the MHA model weights
        config_path (str): the path of the model config (unused here; the config is loaded from gqa_model_path)
    """
    config = AutoConfig.from_pretrained(gqa_model_path)

    model = AutoModelForCausalLM.from_pretrained(gqa_model_path)
    model_state = load_tp_checkpoint(gqa_model_path, model, config)

    model_type = config["model_type"]
    hidden_size = config["hidden_size"]
    num_head = config["num_attention_heads"]
    num_key_value_heads = config["num_key_value_heads"]
    dim_head = hidden_size // num_head
    num_layers = config["num_hidden_layers"]
    num_gqa_partitions = num_head // num_key_value_heads

    for i in range(num_layers):
        print(f"layer: {i}")
        # qkv weight [hidden_size, (num_head + 2 * num_key_value_heads) * dim_head]
        q_weight = model_state[f"{model_type}.layers.{i}.self_attn.q_proj.weight"]
        k_weight = model_state[f"{model_type}.layers.{i}.self_attn.k_proj.weight"]
        v_weight = model_state[f"{model_type}.layers.{i}.self_attn.v_proj.weight"]
        print(f"q_weight.shape: {q_weight.shape}")
        print(f"k_weight.shape: {k_weight.shape}")
        print(f"v_weight.shape: {v_weight.shape}")

        k_weight = k_weight.reshape([hidden_size, num_key_value_heads, dim_head])
        v_weight = v_weight.reshape([hidden_size, num_key_value_heads, dim_head])
        print(f"(reshape) k_weight.shape: {k_weight.shape}")
        print(f"(reshape) v_weight.shape: {v_weight.shape}")

        kk_weight = paddle.reshape(
            paddle.stack([k_weight] * num_gqa_partitions, axis=2), [hidden_size, num_head, dim_head]
        )
        vv_weight = paddle.reshape(
            paddle.stack([v_weight] * num_gqa_partitions, axis=2), [hidden_size, num_head, dim_head]
        )
        print(f"(extend) k_weight.shape: {kk_weight.shape}")
        print(f"(extend) v_weight.shape: {vv_weight.shape}")

        new_k_weight = kk_weight.reshape([hidden_size, num_head * dim_head])
        new_v_weight = vv_weight.reshape([hidden_size, num_head * dim_head])
        print(f"new_k_weight.shape: {new_k_weight.shape}")
        print(f"new_v_weight.shape: {new_v_weight.shape}")

        model_state[f"{model_type}.layers.{i}.self_attn.k_proj.weight"] = new_k_weight
        model_state[f"{model_type}.layers.{i}.self_attn.v_proj.weight"] = new_v_weight

        if (
            f"{model_type}.layers.{i}.self_attn.q_proj.bias" in model_state
            and f"{model_type}.layers.{i}.self_attn.k_proj.bias" in model_state
            and f"{model_type}.layers.{i}.self_attn.v_proj.bias" in model_state
        ):
            print("converting qkv bias")

            q_bias = model_state[f"{model_type}.layers.{i}.self_attn.q_proj.bias"]
            k_bias = model_state[f"{model_type}.layers.{i}.self_attn.k_proj.bias"]
            v_bias = model_state[f"{model_type}.layers.{i}.self_attn.v_proj.bias"]
            print(f"q_bias.shape: {q_bias.shape}")
            print(f"k_bias.shape: {k_bias.shape}")
            print(f"v_bias.shape: {v_bias.shape}")

            k_bias = k_bias.reshape([num_key_value_heads, dim_head])
            v_bias = v_bias.reshape([num_key_value_heads, dim_head])
            print(f"(reshape) k_bias.shape: {k_bias.shape}")
            print(f"(reshape) v_bias.shape: {v_bias.shape}")

            kk_bias = paddle.reshape(paddle.stack([k_bias] * num_gqa_partitions, axis=1), [num_head, dim_head])
            vv_bias = paddle.reshape(paddle.stack([v_bias] * num_gqa_partitions, axis=1), [num_head, dim_head])
            print(f"(extend) k_bias.shape: {kk_bias.shape}")
            print(f"(extend) v_bias.shape: {vv_bias.shape}")

            new_k_bias = kk_bias.reshape([num_head * dim_head])
            new_v_bias = vv_bias.reshape([num_head * dim_head])
            print(f"new_k_bias.shape: {new_k_bias.shape}")
            print(f"new_v_bias.shape: {new_v_bias.shape}")

            model_state[f"{model_type}.layers.{i}.self_attn.k_proj.bias"] = new_k_bias
            model_state[f"{model_type}.layers.{i}.self_attn.v_proj.bias"] = new_v_bias

    paddle.save(model_state, mha_model_path)


if __name__ == "__main__":
    """
    Script to convert model weights from GQA to MHA.
    """
    args = parse_arguments()
    config_path = os.path.join(args.gqa_model_path, "config.json")
    mha_model_path = os.path.join(args.mha_model_path, f"{args.model_prefix_name}.pdparams")

    assert os.path.exists(config_path), "config.json is not found in {}".format(args.gqa_model_path)
    assert os.path.exists(args.gqa_model_path), "{} is not found".format(args.gqa_model_path)
    convert(args.gqa_model_path, mha_model_path, config_path)
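
The loop above turns a GQA checkpoint into an MHA one by repeating each key/value head num_attention_heads // num_key_value_heads times. Below is a minimal numpy sketch of the same stack-and-reshape step, with toy sizes chosen only for illustration (the real script does this with paddle tensors):

import numpy as np

hidden_size, num_head, num_key_value_heads = 8, 4, 2
dim_head = hidden_size // num_head                      # 2
num_gqa_partitions = num_head // num_key_value_heads    # 2

# toy GQA k_proj weight: [hidden_size, num_key_value_heads * dim_head]
k_weight = np.arange(hidden_size * num_key_value_heads * dim_head, dtype=np.float32).reshape(
    [hidden_size, num_key_value_heads * dim_head]
)

# split into kv heads, repeat each head, then flatten back to an MHA-shaped matrix
k_heads = k_weight.reshape([hidden_size, num_key_value_heads, dim_head])
k_repeated = np.stack([k_heads] * num_gqa_partitions, axis=2).reshape([hidden_size, num_head, dim_head])
new_k_weight = k_repeated.reshape([hidden_size, num_head * dim_head])

print(new_k_weight.shape)  # (8, 8): [hidden_size, num_head * dim_head]
assert np.array_equal(k_repeated[:, 0], k_repeated[:, 1])  # query heads 0 and 1 now share kv head 0

The resulting layout puts the num_gqa_partitions copies of each kv head next to each other, which matches the standard GQA grouping where query head h reads kv head h // num_gqa_partitions.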
24 changes: 24 additions & 0 deletions llm/tools/run_split_weights.sh
@@ -0,0 +1,24 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export device="0,1,3,4"
export CUDA_VISIBLE_DEVICES=${device}

model_path=${1-"/path/to/model"}

python -m paddle.distributed.launch \
    --gpus ${device} \
    split_weights.py \
    --model_path ${model_path} \
    --output_path ${model_path}/tp4
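
After this launch finishes, split_weights.py (below) leaves one weight shard per rank in the output directory, plus per-rank scale files when the corresponding JSON files exist next to the source weights. A small sketch of the layout expected under ${model_path}/tp4; the model path here is a placeholder and the file names follow the save calls in split_weights.py:

import os

model_path = "/path/to/model"                  # placeholder, same as the launch script default
output_path = os.path.join(model_path, "tp4")
nranks = 4                                     # one rank per GPU listed in the launch script

for r in range(nranks):
    expected = [
        f"model_state.tp0{r}.pdparams",        # always written by split_weights.py
        f"weight_scales_{r}.json",             # only if weight_scales_0.json exists in model_path
        f"cachekv_scales_{r}.json",            # only if cachekv_scales_0.json exists
        f"act_scales_{r}.json",                # only if act_scales_0.json exists
    ]
    for name in expected:
        print(name, "->", os.path.exists(os.path.join(output_path, name)))

Rank 0 also copies the tokenizer, config, and generation_config into the same output directory.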
156 changes: 156 additions & 0 deletions llm/tools/split_weights.py
@@ -0,0 +1,156 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import numpy as np
import paddle

from paddlenlp.generation import GenerationConfig
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.transformers.model_utils import load_tp_checkpoint
from paddlenlp.utils import llm_utils


def parse_arguments():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default=None, type=str, required=True, help="The directory of the model.")
    parser.add_argument("--output_path", default=None, type=str, help="The directory to save the split model.")
    parser.add_argument(
        "--model_rank_id",
        default=None,
        type=int,
        help="The 1-based rank id of the input model shard, when splitting an already tensor-parallel checkpoint.",
    )
    parser.add_argument("--dtype", default="float16", type=str, help="The dtype of model weights.")
    return parser.parse_args()


def split(args):
    """
    Split model weights across tensor-parallel ranks.
    """
    rank, nranks = llm_utils.init_dist_env()

    if args.output_path is None:
        args.output_path = os.path.join(args.model_path, f"{nranks}_ranks")

    paddle.set_default_dtype(args.dtype)

    config = AutoConfig.from_pretrained(args.model_path)
    config.tensor_parallel_degree = nranks
    config.tensor_parallel_rank = rank

    generation_config = GenerationConfig.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    if args.model_rank_id is not None:
        model_path = os.path.join(args.model_path, f"model_state.tp0{args.model_rank_id - 1}.pdparams")
        assert os.path.isfile(model_path), f"{model_path} does not exist"
        state_dict = load_tp_checkpoint(args.model_path, model, config)
        model_rank = args.model_rank_id
        save_base_rank = model_rank * nranks
    else:
        state_dict = load_tp_checkpoint(args.model_path, model, config)
        model_rank = 0
        save_base_rank = 0

    weight_file = os.path.join(args.output_path, f"model_state.tp0{rank + save_base_rank}.pdparams")
    paddle.save(state_dict, weight_file)

    # process weight scales
    possible_weight_scales_path = os.path.join(args.model_path, f"weight_scales_{model_rank}.json")
    if os.path.exists(possible_weight_scales_path) and rank == 0:
        with open(possible_weight_scales_path, "r") as f:
            weight_scales_dict = json.load(f)

        processed_weight_scales = [{} for _ in range(nranks)]
        # scales of the q/k/v/gate/up projections are split along the last axis;
        # scales of the o/down projections are kept whole on every rank
        split_keys = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "mlp.gate_proj", "mlp.up_proj")
        keep_keys = ("self_attn.o_proj", "mlp.down_proj")
        for k, v in weight_scales_dict.items():
            if any(name in k for name in split_keys):
                splited_value = np.split(np.array(v), nranks, axis=-1)
                for tp_rank in range(nranks):
                    processed_weight_scales[tp_rank][k] = splited_value[tp_rank].tolist()
            elif any(name in k for name in keep_keys):
                for tp_rank in range(nranks):
                    processed_weight_scales[tp_rank][k] = v
            else:
                raise ValueError(f"key {k} is not supported!")

        for tp_rank in range(nranks):
            save_path = os.path.join(args.output_path, f"weight_scales_{tp_rank + save_base_rank}.json")
            with open(save_path, "w") as f:
                print("weight scale save_path:", save_path)
                json.dump(processed_weight_scales[tp_rank], f)

    # process cachekv scales
    possible_cache_path = os.path.join(args.model_path, f"cachekv_scales_{model_rank}.json")
    if os.path.exists(possible_cache_path) and rank == 0:
        with open(possible_cache_path, "r") as f:
            cache_dict = json.load(f)

        processed_cachekv_scales = [{} for _ in range(nranks)]
        for k, v in cache_dict.items():
            # flatten the per-head scales, then split them evenly across ranks
            splited_value = np.split(np.array(v).flatten(), nranks, axis=-1)
            for tp_rank in range(nranks):
                processed_cachekv_scales[tp_rank][k] = splited_value[tp_rank].tolist()
        for tp_rank in range(nranks):
            save_path = os.path.join(args.output_path, f"cachekv_scales_{tp_rank + save_base_rank}.json")
            print("cachekv scale save_path:", save_path)
            with open(save_path, "w") as f:
                json.dump(processed_cachekv_scales[tp_rank], f)

    # process act scales
    possible_act_scales_path = os.path.join(args.model_path, f"act_scales_{model_rank}.json")
    if os.path.exists(possible_act_scales_path) and rank == 0:
        with open(possible_act_scales_path, "r") as f:
            act_scale = json.load(f)
        for tp_rank in range(nranks):
            save_path = os.path.join(args.output_path, f"act_scales_{tp_rank + save_base_rank}.json")
            with open(save_path, "w") as outf:
                print("act scale save_path:", save_path)
                json.dump(act_scale, outf)

    if rank == 0:
        tokenizer.save_pretrained(args.output_path)
        config.save_pretrained(args.output_path)
        generation_config.save_pretrained(args.output_path)


if __name__ == "__main__":
    """
    Script to split model weights.
    """
    args = parse_arguments()
    split(args)
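
To make the scale handling in split() concrete: q/k/v/gate/up scales are sliced along the last axis, one slice per rank, while o/down scales are copied whole to every rank. A toy sketch of that rule with made-up keys and values:

import numpy as np

nranks = 2  # toy tensor-parallel degree
# made-up weight_scales payload; real files hold one scale per output channel
weight_scales = {
    "llama.layers.0.self_attn.q_proj": [0, 1, 2, 3, 4, 5, 6, 7],
    "llama.layers.0.self_attn.o_proj": [0, 1, 2, 3],
}

per_rank = [{} for _ in range(nranks)]
for k, v in weight_scales.items():
    if "o_proj" in k or "down_proj" in k:
        for r in range(nranks):
            per_rank[r][k] = v                   # replicated on every rank
    else:
        parts = np.split(np.array(v), nranks, axis=-1)
        for r in range(nranks):
            per_rank[r][k] = parts[r].tolist()   # one slice per rank

print(per_rank[0])  # q_proj -> [0, 1, 2, 3], o_proj -> [0, 1, 2, 3]
print(per_rank[1])  # q_proj -> [4, 5, 6, 7], o_proj -> [0, 1, 2, 3]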
