[LLM] Support fuse attention q, k, v weights (#8202)
1. add use-interface & fuse action

1.1. modify 1., code order

2. switch to name_mapping

3. solve tp branch

3.2 follow hui, handle qkv separately

3.3 handle pdparams

3.4 from torch

3.5 abandon low_cpu_mem_usage

3.6 solve shard branch

* 3.6.1 solve shard branch after rebase develop

* code clean

* remove debug comment

* Redefine fuse and split functions

* Redefine fuse and split functions

* comment and fix

* update method

* update QKV fuse and split

* support fusing weights across multiple files

* add precision compare

* simplify function call

* support use_fast_ffn

* clean modeling and configuration

* add test for gpt and opt

* fix tp_actions get

* add fast_ffn test

* add Qwen2Moe

* Revert "add Qwen2Moe"

This reverts commit 113b883.

* add test for split

* update doc

* update filter_dict_keys

---------

Co-authored-by: Zii <ziangqin.baidu@gmail.com>
DrownFish19 and ziangqin-baidu committed Apr 25, 2024
1 parent 1a73b76 commit f29a7b9
Showing 9 changed files with 701 additions and 6 deletions.
237 changes: 237 additions & 0 deletions paddlenlp/transformers/conversion_utils.py
@@ -489,6 +489,118 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads):
return naive_merged_qkv_to_tensor_parallel_qkv(weight)


def fuse_param_func():
def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None):
"""fuse function for fusing weights
(1) fuse_attention_qkv
q => [q1,q2,q3,q4]
k => [k1,k2,k3,k4] or [k1,k2] for GQA
v => [v1,v2,v3,v4] or [v1,v2] for GQA
fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
(2) fuse_attention_ffn
directly fuse the weights into one part
[gate_weight], [up_weight] => [gate_weight, up_weight]
Args:
fuse_params (_type_): weights to be fused
is_qkv (bool, optional): for attention qkv weights. Defaults to False.
num_heads (_type_, optional): query heads. Defaults to None.
num_key_value_heads (_type_, optional): key and value heads. Defaults to None.
Returns:
_type_: fused weights
"""
concat_fn = np.concatenate
split_fn = np.split
if isinstance(fuse_params[0], paddle.Tensor):
concat_fn = paddle.concat
split_fn = paddle.split

if is_qkv:
# fuse_attention_qkv
assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}"
assert (
num_key_value_heads
), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
assert (
len(fuse_params) == 3
), f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}"
num_query_groups = num_heads // num_key_value_heads
q_list = split_fn(fuse_params[0], num_heads, axis=-1)
k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1)
v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1)

qkv_pairs = []
for i in range(num_key_value_heads):
qkv_pairs += q_list[i * num_query_groups : (i + 1) * num_query_groups]
qkv_pairs.append(k_list[i])
qkv_pairs.append(v_list[i])
return concat_fn(qkv_pairs, axis=-1)
else:
# fuse_attention_ffn
return concat_fn(fuse_params, axis=-1)

return fn


def split_param_func():
def fn(fused_param, split_nums=2, is_qkv=False, num_heads=None, num_key_value_heads=None):
"""split function for splitting weights
(1) fuse_attention_qkv
fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
after split
q => [q1,q2,q3,q4]
k => [k1,k2,k3,k4] or [k1,k2] for GQA
v => [v1,v2,v3,v4] or [v1,v2] for GQA
(2) fuse_attention_ffn
directly split the weight into 2 parts
[gate_weight, up_weight] => [gate_weight], [up_weight]
Args:
fused_param (_type_): the single fused weight to be split
split_nums (int, optional): number of parts to split into. Defaults to 2.
is_qkv (bool, optional): for attention qkv weights. Defaults to False.
num_heads (_type_, optional): query heads. Defaults to None.
num_key_value_heads (_type_, optional): key and value heads. Defaults to None.
Returns:
_type_: split weights
"""
concat_fn = np.concatenate
split_fn = np.split
if isinstance(fused_param, paddle.Tensor):
concat_fn = paddle.concat
split_fn = paddle.split

if is_qkv:
# fuse_attention_qkv
assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}"
assert (
num_key_value_heads
), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
num_query_groups = num_heads // num_key_value_heads
q_list, k_list, v_list = [], [], []
split_heads = split_fn(fused_param, num_heads + 2 * num_key_value_heads, axis=-1)
for i in range(num_key_value_heads):
q_list += split_heads[i * (num_query_groups + 2) : (i + 1) * (num_query_groups + 2) - 2]
k_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 2])
v_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 1])
return concat_fn(q_list, axis=-1), concat_fn(k_list, axis=-1), concat_fn(v_list, axis=-1)
else:
# fuse_attention_ffn
return split_fn(fused_param, split_nums, axis=-1)

return fn


def split_or_fuse_func(is_fuse=True):
return fuse_param_func() if is_fuse else split_param_func()
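
A minimal round-trip sketch of the two helpers above (the hidden size, head counts, and random weights below are assumed for illustration only):

import numpy as np

# assumed toy GQA shapes: 4 query heads, 2 key/value heads, head_dim 8, hidden size 16
hidden, head_dim, num_heads, num_kv_heads = 16, 8, 4, 2
q = np.random.rand(hidden, num_heads * head_dim)
k = np.random.rand(hidden, num_kv_heads * head_dim)
v = np.random.rand(hidden, num_kv_heads * head_dim)

fuse, split = fuse_param_func(), split_param_func()

# fused column order for this GQA case: [q1, q2, k1, v1, q3, q4, k2, v2]
qkv = fuse([q, k, v], is_qkv=True, num_heads=num_heads, num_key_value_heads=num_kv_heads)
q2, k2, v2 = split(qkv, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_kv_heads)
assert all(np.allclose(a, b) for a, b in zip((q, k, v), (q2, k2, v2)))

# the FFN case simply concatenates along the last axis and splits it back into two parts
gate_up = fuse([np.ones((4, 3)), np.zeros((4, 3))])
gate, up = split(gate_up, split_nums=2)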


def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
def fn(
x,
@@ -1100,6 +1212,7 @@ def convert_tensor_parallel(
weight_file (str | None): the weight file path of `model_state.pdparams` file
config (PretrainedConfig): the PretrainedConfig instance of model
"""

name_action_mappings = cls._get_tensor_parallel_mappings(config)
if state_dict is None:
with device_guard("cpu"):
@@ -1201,6 +1314,130 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):

return state_keys_map

@classmethod
def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
loaded_keys = state_dict.keys()
# collect and convert fuse/split action
fused_and_split_keys = []
fuse_actions, resume_keys = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
for keys, action in fuse_actions.items():
origin_states = [state_dict.pop(key) for key in keys[:-1]]
state_dict[keys[-1]] = action(origin_states)
fused_and_split_keys.append(keys[-1])
logger.info(f"Fusing parameter: {keys[:-1]} into {keys[-1]}")

split_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False)
for keys, action in split_actions.items():
origin_state = state_dict.pop(keys[-1])
split_states = action(origin_state)
for key_idx, key in enumerate(keys[:-1]):
state_dict[key] = split_states[key_idx]
fused_and_split_keys.append(key)
logger.info(f"Splitting parameter: {keys[-1]} into {keys[:-1]}")

if tp_actions is not None:
for key in fused_and_split_keys:
for name in tp_actions.keys():
if key.endswith(name):
with device_guard():
state_dict[key] = paddle.Tensor(tp_actions[name](state_dict.pop(key)), zero_copy=True)
break

# when the shard files split the weights as follows, some weights must be carried over (resumed) for the next shard file
# shard-001-file: q_weight, k_weight
# shard-002-file: v_weight
resume_state_dict = {k: state_dict[k] for k in resume_keys if k in state_dict}
return state_dict, resume_state_dict
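
A hedged sketch of the shard-resume flow described in the comment above; load_shard is a hypothetical loader, and the call through LlamaForCausalLM assumes this classmethod is inherited from the mixin:

# hypothetical: shard-001 holds q_proj and k_proj weights, shard-002 holds v_proj
state_dict, resume = LlamaForCausalLM.convert_fuse_and_split(config, load_shard("shard-001"))
# the qkv fuse action cannot run yet, so the pending q/k weights come back in `resume`

next_state_dict = load_shard("shard-002")
next_state_dict.update(resume)  # carry the pending q/k weights into the next shard
state_dict, resume = LlamaForCausalLM.convert_fuse_and_split(config, next_state_dict)
# q, k and v are now all present, so the fused qkv_proj weight is produced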

@classmethod
def get_fuse_or_split_param_convert_actions(
cls,
config: PretrainedConfig,
loaded_state_dict_keys,
is_fuse=True,
ignore_error=False,
):
name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse
)
for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

# filter name_action_mappings by the weights that are actually loaded
# fusing: keep an action only if all source keys (all but the last) are in loaded_state_dict_keys
# splitting: keep an action only if the fused key (the last one) is in loaded_state_dict_keys
filter_name_action = {}
resume_keys = []
if is_fuse:
for k, v in name_action_mappings.items():
cond = True
if not all(item in loaded_state_dict_keys for item in k[:-1]):
# resume keys for next fuse
resume_keys += k[:-1]
cond = False
if cond:
filter_name_action[k] = v
else:
for k, v in name_action_mappings.items():
if k[-1] in loaded_state_dict_keys:
filter_name_action[k] = v

return filter_name_action, resume_keys

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> dict:
"""get the fuse or split parameter mappings of PretrainedModel
Args:
config (PretrainedConfig): the model configuration used to build the name mappings
is_fuse (bool, optional): return fuse mappings if True, split mappings otherwise. Defaults to True.
Returns:
dict: a mapping from tuples of parameter names to their fuse/split functions
"""
# raise NotImplementedError(
# f"`_get_fuse_or_split_param_mappings` is not implemented for {cls.__name__}`. To implement it, you should "
# f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`"
# )
return {}

@staticmethod
def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True):
state_keys_map = {}

# each key is a tuple such as (x1, x2, x3, x4): when fusing, the prefix found on x1, x2, x3 is applied to
# the new fused key x4; when splitting, the prefix found on x4 is applied to the new split keys x1, x2, x3.
# The tuple can also be (a) (x1, x1) -> convert x1 to x1; (b) (x1, x2, x3) -> fuse x1 and x2 into x3;
# (c) (x1, x2, x3, x4) -> fuse x1, x2 and x3 into x4.

# is_fuse: True -> fuse, False -> split
# True: (x1, x2, x3, x4) -> [x1, x2, x3] exist in state_keys_real, x4 does not
# False: (x1, x2, x3, x4) -> [x1, x2, x3] do not exist in state_keys_real, x4 does

for keys in state_keys_base:
prefix = ""
if is_fuse:
for x in state_keys_real:
for base_key in keys[:-1]:
if x.endswith(base_key):
prefix = x.replace(base_key, "")
break
if prefix != "":
break
else:
base_key = keys[-1]
for x in state_keys_real:
if x.endswith(base_key):
prefix = x.replace(base_key, "")
break

new_keys = tuple([prefix + key for key in keys])
state_keys_map[keys] = new_keys

return state_keys_map
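
For illustration (the key names and the "llama." prefix are assumed, and the staticmethod is assumed to live on ConversionMixin as in this file), prefix resolution for a fuse tuple works like this:

base_keys = [(
    "layers.0.self_attn.q_proj.weight",
    "layers.0.self_attn.k_proj.weight",
    "layers.0.self_attn.v_proj.weight",
    "layers.0.self_attn.qkv_proj.weight",  # fuse target, not present in the checkpoint yet
)]
real_keys = [
    "llama.layers.0.self_attn.q_proj.weight",
    "llama.layers.0.self_attn.k_proj.weight",
    "llama.layers.0.self_attn.v_proj.weight",
]
mapping = ConversionMixin._resolve_prefix_keys_for_fuse_and_split(base_keys, real_keys, is_fuse=True)
# mapping[base_keys[0]] is the same four names, each prefixed with "llama."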


class Converter(ConversionMixin, LogitComparer):
"""some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down.
43 changes: 43 additions & 0 deletions paddlenlp/transformers/gpt/modeling.py
@@ -844,6 +844,49 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: GPTConfig, is_fuse=False):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# the last key is the fused key; the preceding keys are the ones to be fused.
fuse_qkv_keys = (
"decoder.layers.0.self_attn.q_proj.weight",
"decoder.layers.0.self_attn.k_proj.weight",
"decoder.layers.0.self_attn.v_proj.weight",
"decoder.layers.0.self_attn.qkv_proj.weight",
)
fuse_qkv_bias_keys = (
"decoder.layers.0.self_attn.q_proj.bias",
"decoder.layers.0.self_attn.k_proj.bias",
"decoder.layers.0.self_attn.v_proj.bias",
"decoder.layers.0.self_attn.qkv_proj.bias",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]:
new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys])
final_actions[new_keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
else:
if not fuse_attention_qkv:
for i in range(config.num_hidden_layers):
for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]:
new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys])
final_actions[new_keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
return final_actions

@classmethod
def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
1 change: 1 addition & 0 deletions paddlenlp/transformers/gpt/modeling_pp.py
@@ -161,6 +161,7 @@ class GPTForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
config_class = GPTConfig

_get_tensor_parallel_mappings = GPTPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = GPTPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = GPTPretrainedModel._init_weights

pretrained_init_configuration = GPTPretrainedModel.pretrained_init_configuration
50 changes: 50 additions & 0 deletions paddlenlp/transformers/llama/modeling.py
@@ -1293,6 +1293,56 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# the last key is the fused key; the preceding keys are the ones to be fused.
fuse_qkv_keys = (
"layers.0.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight",
"layers.0.self_attn.v_proj.weight",
"layers.0.self_attn.qkv_proj.weight",
)

fuse_gate_up_keys = (
"layers.0.mlp.gate_proj.weight",
"layers.0.mlp.up_proj.weight",
"layers.0.mlp.gate_up_fused_proj.weight",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = fn
else:
if not fuse_attention_qkv:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if not fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = partial(fn, split_nums=2)
return final_actions
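
As a rough illustration of the mapping above (the config values are assumed, and fuse_attention_qkv / fuse_attention_ffn are assumed to be valid LlamaConfig fields), the fuse actions for a two-layer model would be keyed per layer:

cfg = LlamaConfig(num_hidden_layers=2, fuse_attention_qkv=True, fuse_attention_ffn=True)
actions = LlamaPretrainedModel._get_fuse_or_split_param_mappings(cfg, is_fuse=True)
for keys in actions:
    print(keys)
# ('layers.0.self_attn.q_proj.weight', ..., 'layers.0.self_attn.qkv_proj.weight')
# ('layers.1.self_attn.q_proj.weight', ..., 'layers.1.self_attn.qkv_proj.weight')
# ('layers.0.mlp.gate_proj.weight', 'layers.0.mlp.up_proj.weight', 'layers.0.mlp.gate_up_fused_proj.weight')
# ('layers.1.mlp.gate_proj.weight', 'layers.1.mlp.up_proj.weight', 'layers.1.mlp.gate_up_fused_proj.weight')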

def _init_weights(self, layer):
"""Initialization hook"""
if self.config.tensor_parallel_degree > 1:
1 change: 1 addition & 0 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -210,6 +210,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
config_class = LlamaConfig

_get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = LlamaPretrainedModel._init_weights
_keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected
