
[LLM] Support fuse attention q, k, v weights #8202

Merged
29 commits merged on Apr 25, 2024
Commits (29)
d859826
1. add use-interface & fuse action
ziangqin-baidu Feb 5, 2024
f7b6973
3.6.1 solve shard branch after rebase develop
ziangqin-baidu Mar 11, 2024
05507a7
code clean
ziangqin-baidu Mar 17, 2024
b8b828b
Merge remote-tracking branch 'qinziang/fuse' into dev-fuse-qkv
DrownFish19 Mar 28, 2024
28ed30f
remove debug comment
DrownFish19 Mar 29, 2024
1a1591a
Redefine fuse and split functions
DrownFish19 Apr 1, 2024
ef1fb18
Redefine fuse and split functions
DrownFish19 Apr 1, 2024
f56e602
Merge branch 'dev-fuse-qkv' of github.com:DrownFish19/PaddleNLP into …
DrownFish19 Apr 2, 2024
19faa26
comment and fix
DrownFish19 Apr 2, 2024
4a92dde
update method
DrownFish19 Apr 2, 2024
c0a71d2
update QKV fuse and split
DrownFish19 Apr 9, 2024
d1e7d17
support fuse weights in multi-files
DrownFish19 Apr 10, 2024
58a2c23
add precision compare
DrownFish19 Apr 10, 2024
7e1ab08
simplify function call
DrownFish19 Apr 11, 2024
0458f4e
support use_fast_ffn
DrownFish19 Apr 11, 2024
920155b
clean modeling and configuration
DrownFish19 Apr 11, 2024
a90fcfc
add test for gpt and opt
DrownFish19 Apr 11, 2024
774bb6f
fix tp_actions get
DrownFish19 Apr 11, 2024
057198c
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 11, 2024
db78c6a
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 12, 2024
a1aa078
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 12, 2024
f0d19f6
add fast_ffn test
DrownFish19 Apr 12, 2024
110983d
Merge branch 'dev-fuse-qkv' of github.com:DrownFish19/PaddleNLP into …
DrownFish19 Apr 12, 2024
113b883
add Qwen2Moe
DrownFish19 Apr 16, 2024
1fb0a7d
Revert "add Qwen2Moe"
DrownFish19 Apr 17, 2024
a118f1b
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 17, 2024
11d0793
add test for split
DrownFish19 Apr 17, 2024
0c98a45
update doc
DrownFish19 Apr 17, 2024
f6f3b0e
update filter_dict_keys
DrownFish19 Apr 18, 2024
237 changes: 237 additions & 0 deletions paddlenlp/transformers/conversion_utils.py
@@ -489,6 +489,118 @@
return naive_merged_qkv_to_tensor_parallel_qkv(weight)


def fuse_param_func():
def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None):
"""fuse function for fusing weights

(1) fuse_attention_qkv
q => [q1,q2,q3,q4]
k => [k1,k2,k3,k4] or [k1,k2] for GQA
v => [v1,v2,v3,v4] or [v1,v2] for GQA
fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
(2) fuse_attention_ffn
directly fuse weights to 1 parts
[gate_weight], [up_weight] => [gate_weight, up_weight]

Args:
fuse_params (_type_): to be fused weights
is_qkv (bool, optional): for attention qkv weights. Defaults to False.
num_heads (_type_, optional): query heads. Defaults to None.
num_key_value_heads (_type_, optional): key and value heads. Defaults to None.

Returns:
_type_: fused weights
"""
concat_fn = np.concatenate
split_fn = np.split
if isinstance(fuse_params[0], paddle.Tensor):
concat_fn = paddle.concat
split_fn = paddle.split

if is_qkv:
# fuse_attention_qkv
assert num_heads, f"num_heads should be the number of query heads, but got {num_heads}"
assert (
num_key_value_heads
), f"num_key_value_heads should be the number of key/value heads, but got {num_key_value_heads}"
assert (
len(fuse_params) == 3
), f"fuse_params should be the [Q, K, V] list of length 3, but got length {len(fuse_params)}"
num_query_groups = num_heads // num_key_value_heads
q_list = split_fn(fuse_params[0], num_heads, axis=-1)
k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1)
v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1)

qkv_pairs = []
for i in range(num_key_value_heads):
qkv_pairs += q_list[i * num_query_groups : (i + 1) * num_query_groups]
qkv_pairs.append(k_list[i])
qkv_pairs.append(v_list[i])
return concat_fn(qkv_pairs, axis=-1)
else:
# fuse_attention_ffn
return concat_fn(fuse_params, axis=-1)

return fn
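
For intuition, here is a minimal NumPy sketch (an illustrative example, not part of the diff) of the interleaved layout described in the docstring, with 4 query heads, 2 key/value heads, and a head dim of 2; it assumes fuse_param_func is importable from paddlenlp.transformers.conversion_utils as added in this PR:

import numpy as np
from paddlenlp.transformers.conversion_utils import fuse_param_func

num_heads, num_kv_heads, head_dim, in_dim = 4, 2, 2, 3
q = np.arange(in_dim * num_heads * head_dim).reshape(in_dim, num_heads * head_dim)
k = 100 + np.arange(in_dim * num_kv_heads * head_dim).reshape(in_dim, num_kv_heads * head_dim)
v = 200 + np.arange(in_dim * num_kv_heads * head_dim).reshape(in_dim, num_kv_heads * head_dim)

fuse = fuse_param_func()
qkv = fuse([q, k, v], is_qkv=True, num_heads=num_heads, num_key_value_heads=num_kv_heads)
print(qkv.shape)  # (3, 16); columns grouped as [q1, q2, k1, v1, q3, q4, k2, v2], each group head_dim wide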


def split_param_func():
def fn(fused_param, split_nums=2, is_qkv=False, num_heads=None, num_key_value_heads=None):
"""split function for splitting weights

(1) fuse_attention_qkv
fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
after split
q => [q1,q2,q3,q4]
k => [k1,k2,k3,k4] or [k1,k2] for GQA
v => [v1,v2,v3,v4] or [v1,v2] for GQA
(2) fuse_attention_ffn
directly split weight to 2 parts
[gate_weight, up_weight] => [gate_weight], [up_weight]

Args:
fused_param (_type_): len(fused_param)=1, only one weight to be splitted
split_nums (int, optional): split_nums. Defaults to 2.
is_qkv (bool, optional): for attention qkv weights. Defaults to False.
num_heads (_type_, optional): query heads. Defaults to None.
num_key_value_heads (_type_, optional): key and value heads. Defaults to None.

Returns:
_type_: splitted weights
"""
concat_fn = np.concatenate
split_fn = np.split
if isinstance(fused_param, paddle.Tensor):
concat_fn = paddle.concat
split_fn = paddle.split

if is_qkv:
# fuse_attention_qkv
assert num_heads, f"num_heads should be the number of query heads, but got {num_heads}"
assert (
num_key_value_heads
), f"num_key_value_heads should be the number of key/value heads, but got {num_key_value_heads}"
num_query_groups = num_heads // num_key_value_heads
q_list, k_list, v_list = [], [], []
split_heads = split_fn(fused_param, num_heads + 2 * num_key_value_heads, axis=-1)
for i in range(num_key_value_heads):
q_list += split_heads[i * (num_query_groups + 2) : (i + 1) * (num_query_groups + 2) - 2]
k_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 2])
v_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 1])
return concat_fn(q_list, axis=-1), concat_fn(k_list, axis=-1), concat_fn(v_list, axis=-1)
else:
# fuse_attention_ffn
return split_fn(fused_param, split_nums, axis=-1)

return fn
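
A matching round-trip check, continuing the sketch above (again illustrative, not part of the diff): splitting the fused tensor recovers the original q, k and v, and the plain path splits a fused gate/up weight into two halves.

from paddlenlp.transformers.conversion_utils import split_param_func

split = split_param_func()
q2, k2, v2 = split(qkv, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_kv_heads)
assert (q2 == q).all() and (k2 == k).all() and (v2 == v).all()

gate_up = np.concatenate([q, q], axis=-1)   # stand-in for a fused [gate_weight, up_weight]
gate, up = split(gate_up, split_nums=2)     # plain fuse_attention_ffn split
assert gate.shape == up.shape == q.shape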


def split_or_fuse_func(is_fuse=True):
return fuse_param_func() if is_fuse else split_param_func()


def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
def fn(
x,
@@ -1100,6 +1212,7 @@
weight_file (str | None): the weight file path of `model_state.pdparams` file
config (PretrainedConfig): the PretrainedConfig instance of model
"""

name_action_mappings = cls._get_tensor_parallel_mappings(config)
if state_dict is None:
with device_guard("cpu"):
@@ -1201,6 +1314,130 @@

return state_keys_map

@classmethod
def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
loaded_keys = state_dict.keys()
# collect and convert fuse/split action
fused_and_split_keys = []
fuse_actions, resume_keys = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
for keys, action in fuse_actions.items():
origin_states = [state_dict.pop(key) for key in keys[:-1]]
state_dict[keys[-1]] = action(origin_states)
fused_and_split_keys.append(keys[-1])
logger.info(f"Fusing parameter: {keys[:-1]} into {keys[-1]}")

split_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False)
for keys, action in split_actions.items():
origin_state = state_dict.pop(keys[-1])
split_states = action(origin_state)
for key_idx, key in enumerate(keys[:-1]):
state_dict[key] = split_states[key_idx]
fused_and_split_keys.append(key)
logger.info(f"Splitting parameter: {keys[-1]} into {keys[:-1]}")

if tp_actions is not None:
for key in fused_and_split_keys:
for name in tp_actions.keys():
if key.endswith(name):
with device_guard():
state_dict[key] = paddle.Tensor(tp_actions[name](state_dict.pop(key)), zero_copy=True)
break

# when a sharded checkpoint splits the weights across files as follows, some weights have to be
# carried over (resumed) for the next shard file:
# shard-001 file: q_weight, k_weight
# shard-002 file: v_weight
resume_state_dict = {k: state_dict[k] for k in resume_keys if k in state_dict}
return state_dict, resume_state_dict
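
As the comment above notes, the source weights of a fused parameter can span shard files, which is why resume_state_dict is returned. A hedged sketch of how a loader might thread it through consecutive shards; the shard names and the load_shard helper are hypothetical (not PaddleNLP API), and MyModel stands for any PretrainedModel subclass:

resume_state_dict = {}
for shard_file in ("model-00001.pdparams", "model-00002.pdparams"):  # hypothetical shard names
    state_dict = load_shard(shard_file)        # hypothetical helper returning a key -> tensor dict
    state_dict.update(resume_state_dict)       # re-inject keys carried over from the previous shard
    state_dict, resume_state_dict = MyModel.convert_fuse_and_split(config, state_dict, tp_actions)
    # state_dict now holds the fused/split (and, if tp_actions is given, TP-sliced) weights of this shard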

@classmethod
def get_fuse_or_split_param_convert_actions(
cls,
config: PretrainedConfig,
loaded_state_dict_keys,
is_fuse=True,
ignore_error=False,
):
name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse
)
for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

# filter name_action_mappings with corresponding weights
# fusing: verify all of the keys in name_action_mappings are in loaded_state_dict_keys
# splitting: verify the last key in name_action_mappings is in loaded_state_dict_keys
filter_name_action = {}
resume_keys = []
if is_fuse:
for k, v in name_action_mappings.items():
cond = True
if not all(item in loaded_state_dict_keys for item in k[:-1]):
# resume keys for next fuse
resume_keys += k[:-1]
cond = False
if cond:
filter_name_action[k] = v
else:
for k, v in name_action_mappings.items():
if k[-1] in loaded_state_dict_keys:
filter_name_action[k] = v

return filter_name_action, resume_keys

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> List[StateDictNameMapping]:
"""get fused parameter mapping of PretrainedModel

Args:
config (PretrainedConfig): the configuration of name-mapping

Raises:
NotImplementedError:

Returns:
List[StateDictNameMapping]: the name-mappings for tensor_parallel
"""
# raise NotImplementedError(
# f"`_get_fuse_or_split_param_mappings` is not implemented for {cls.__name__}`. To implement it, you should "
# f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`"
# )
return {}

@staticmethod
def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True):
state_keys_map = {}

# Each key in state_keys_base is a tuple such as (x1, x2, x3, x4): the prefix found on x1, x2, x3 is used
# to build the new key x4, or the prefix found on the last key x4 is used to build the new keys x1, x2, x3.
# The tuple can also be (a) (x1, x1) -> convert x1 to x1; (b) (x1, x2, x3) -> fuse x1 and x2 into x3;
# (c) (x1, x2, x3, x4) -> fuse x1, x2 and x3 into x4.

# is_fuse: True -> fuse, False -> split
# True: for (x1, x2, x3, x4), [x1, x2, x3] exist in state_keys_real and x4 does not
# False: for (x1, x2, x3, x4), [x1, x2, x3] do not exist in state_keys_real and x4 does

for keys in state_keys_base:
prefix = ""
if is_fuse:
for x in state_keys_real:
for base_key in keys[:-1]:
if x.endswith(base_key):
prefix = x.replace(base_key, "")
break
if prefix != "":
break
else:
base_key = keys[-1]
for x in state_keys_real:
if x.endswith(base_key):
prefix = x.replace(base_key, "")
break

new_keys = tuple([prefix + key for key in keys])
state_keys_map[keys] = new_keys

return state_keys_map
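
To make the prefix matching concrete, a small illustrative call (the key names are made up; the static method is assumed to be reachable on ConversionMixin in this module):

from paddlenlp.transformers.conversion_utils import ConversionMixin

base_keys = [(
    "layers.0.self_attn.q_proj.weight",
    "layers.0.self_attn.k_proj.weight",
    "layers.0.self_attn.v_proj.weight",
    "layers.0.self_attn.qkv_proj.weight",
)]
real_keys = [
    "llama.layers.0.self_attn.q_proj.weight",
    "llama.layers.0.self_attn.k_proj.weight",
    "llama.layers.0.self_attn.v_proj.weight",
]
mapping = ConversionMixin._resolve_prefix_keys_for_fuse_and_split(base_keys, real_keys, is_fuse=True)
# the base tuple maps to the same four names with the "llama." prefix attached,
# including the not-yet-existing fused key "llama.layers.0.self_attn.qkv_proj.weight"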


class Converter(ConversionMixin, LogitComparer):
"""some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down.
43 changes: 43 additions & 0 deletions paddlenlp/transformers/gpt/modeling.py
@@ -844,6 +844,49 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: GPTConfig, is_fuse=False):
# build the parameter fuse/split mappings for this model
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# the last key is the fused key; the preceding keys are the ones to be fused.
fuse_qkv_keys = (
"decoder.layers.0.self_attn.q_proj.weight",
"decoder.layers.0.self_attn.k_proj.weight",
"decoder.layers.0.self_attn.v_proj.weight",
"decoder.layers.0.self_attn.qkv_proj.weight",
)
fuse_qkv_bias_keys = (
"decoder.layers.0.self_attn.q_proj.bias",
"decoder.layers.0.self_attn.k_proj.bias",
"decoder.layers.0.self_attn.v_proj.bias",
"decoder.layers.0.self_attn.qkv_proj.bias",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]:
new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys])
final_actions[new_keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
else:
if not fuse_attention_qkv:
for i in range(config.num_hidden_layers):
for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]:
new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys])
final_actions[new_keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
return final_actions
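
Roughly how the layer-0 template keys expand across layers; a sketch only, assuming GPTConfig accepts these keyword arguments:

from paddlenlp.transformers import GPTConfig, GPTPretrainedModel

cfg = GPTConfig(num_hidden_layers=2, fuse_attention_qkv=True)  # assumed constructor kwargs
actions = GPTPretrainedModel._get_fuse_or_split_param_mappings(cfg, is_fuse=True)
# one entry per layer for the weight tuple and one for the bias tuple, e.g. for layer 1:
# ("decoder.layers.1.self_attn.q_proj.weight",
#  "decoder.layers.1.self_attn.k_proj.weight",
#  "decoder.layers.1.self_attn.v_proj.weight",
#  "decoder.layers.1.self_attn.qkv_proj.weight") -> partial of the fuse function with is_qkv=True
#                                                   and the head counts bound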

@classmethod
def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
1 change: 1 addition & 0 deletions paddlenlp/transformers/gpt/modeling_pp.py
@@ -161,6 +161,7 @@ class GPTForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
config_class = GPTConfig

_get_tensor_parallel_mappings = GPTPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = GPTPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = GPTPretrainedModel._init_weights

pretrained_init_configuration = GPTPretrainedModel.pretrained_init_configuration
50 changes: 50 additions & 0 deletions paddlenlp/transformers/llama/modeling.py
@@ -1302,6 +1302,56 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
# build the parameter fuse/split mappings for this model
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# the last key is the fused key; the preceding keys are the ones to be fused.
fuse_qkv_keys = (
"layers.0.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight",
"layers.0.self_attn.v_proj.weight",
"layers.0.self_attn.qkv_proj.weight",
)

fuse_gate_up_keys = (
"layers.0.mlp.gate_proj.weight",
"layers.0.mlp.up_proj.weight",
"layers.0.mlp.gate_up_fused_proj.weight",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = fn
else:
if not fuse_attention_qkv:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if not fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = partial(fn, split_nums=2)
return final_actions
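
The gate/up branch does not go through the QKV path, so applying one of these actions is a plain concatenation along the last axis. A minimal sketch under the same kind of assumptions as above (LlamaConfig kwargs assumed):

import numpy as np
from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel

cfg = LlamaConfig(num_hidden_layers=1, fuse_attention_ffn=True)  # assumed constructor kwargs
actions = LlamaPretrainedModel._get_fuse_or_split_param_mappings(cfg, is_fuse=True)
keys = next(k for k in actions if k[-1].endswith("mlp.gate_up_fused_proj.weight"))
gate = np.zeros((8, 32), dtype="float32")   # stands in for layers.0.mlp.gate_proj.weight
up = np.ones((8, 32), dtype="float32")      # stands in for layers.0.mlp.up_proj.weight
fused = actions[keys]([gate, up])           # shape (8, 64): [gate_weight, up_weight]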

def _init_weights(self, layer):
"""Initialization hook"""
if self.config.tensor_parallel_degree > 1:
1 change: 1 addition & 0 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -210,6 +210,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
config_class = LlamaConfig

_get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = LlamaPretrainedModel._init_weights
_keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected
