[LLM] Support fuse attention q, k, v weights #8202

Merged: 29 commits, merged on Apr 25, 2024
Changes from 26 commits
Commits (29)
d859826
1. add use-interface & fuse action
ziangqin-baidu Feb 5, 2024
f7b6973
3.6.1 solve shard branch after rebase develop
ziangqin-baidu Mar 11, 2024
05507a7
code clean
ziangqin-baidu Mar 17, 2024
b8b828b
Merge remote-tracking branch 'qinziang/fuse' into dev-fuse-qkv
DrownFish19 Mar 28, 2024
28ed30f
remove debug comment
DrownFish19 Mar 29, 2024
1a1591a
Redefine fuse and split functions
DrownFish19 Apr 1, 2024
ef1fb18
Redefine fuse and split functions
DrownFish19 Apr 1, 2024
f56e602
Merge branch 'dev-fuse-qkv' of github.com:DrownFish19/PaddleNLP into …
DrownFish19 Apr 2, 2024
19faa26
comment and fix
DrownFish19 Apr 2, 2024
4a92dde
update method
DrownFish19 Apr 2, 2024
c0a71d2
update QKV fuse and split
DrownFish19 Apr 9, 2024
d1e7d17
support fuse weights in multi-files
DrownFish19 Apr 10, 2024
58a2c23
add precision compare
DrownFish19 Apr 10, 2024
7e1ab08
simplify function call
DrownFish19 Apr 11, 2024
0458f4e
support use_fast_ffn
DrownFish19 Apr 11, 2024
920155b
clean modeling and configuration
DrownFish19 Apr 11, 2024
a90fcfc
add test for gpt and opt
DrownFish19 Apr 11, 2024
774bb6f
fix tp_actions get
DrownFish19 Apr 11, 2024
057198c
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 11, 2024
db78c6a
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 12, 2024
a1aa078
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 12, 2024
f0d19f6
add fast_ffn test
DrownFish19 Apr 12, 2024
110983d
Merge branch 'dev-fuse-qkv' of github.com:DrownFish19/PaddleNLP into …
DrownFish19 Apr 12, 2024
113b883
add Qwen2Moe
DrownFish19 Apr 16, 2024
1fb0a7d
Revert "add Qwen2Moe"
DrownFish19 Apr 17, 2024
a118f1b
Merge branch 'PaddlePaddle:develop' into dev-fuse-qkv
DrownFish19 Apr 17, 2024
11d0793
add test for split
DrownFish19 Apr 17, 2024
0c98a45
update doc
DrownFish19 Apr 17, 2024
f6f3b0e
update filter_dict_keys
DrownFish19 Apr 18, 2024
219 changes: 219 additions & 0 deletions paddlenlp/transformers/conversion_utils.py
@@ -489,6 +489,104 @@
return naive_merged_qkv_to_tensor_parallel_qkv(weight)


def fuse_param_func():
def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None, convert_fast_ffn=False):
concat_fn = np.concatenate
split_fn = np.split
if isinstance(fuse_params[0], paddle.Tensor):
concat_fn = paddle.concat
split_fn = paddle.split

if is_qkv:
# fuse_attention_qkv
# q => [q1,q2,q3,q4]
# k => [k1,k2,k3,k4] or [k1,k2] for GQA
# v => [v1,v2,v3,v4] or [v1,v2] for GQA
# fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
# or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}"
assert (
num_key_value_heads
), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
assert (
len(fuse_params) == 3
), f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}"
num_query_groups = num_heads // num_key_value_heads
q_list = split_fn(fuse_params[0], num_heads, axis=-1)
k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1)
v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1)

qkv_pairs = []
for i in range(num_key_value_heads):
qkv_pairs += q_list[i * num_query_groups : (i + 1) * num_query_groups]
qkv_pairs.append(k_list[i])
qkv_pairs.append(v_list[i])
return concat_fn(qkv_pairs, axis=-1)
elif convert_fast_ffn:
# convert_fast_ffn: de-interleave alternating columns ([a1, b1, a2, b2, ...]) into contiguous halves ([a1, a2, ..., b1, b2, ...])
first = fuse_params[0][..., ::2]
second = fuse_params[0][..., 1::2]
return concat_fn([first, second], axis=-1)
else:
# fuse_attention_ffn
return concat_fn(fuse_params, axis=-1)

return fn
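
A minimal sketch of the interleaved layout this helper produces, using toy numpy arrays; the head counts, widths, and values are illustrative assumptions, and the import presumes a PaddleNLP build that already contains fuse_param_func from this diff.

import numpy as np

from paddlenlp.transformers.conversion_utils import fuse_param_func

# Toy GQA shapes: 4 query heads, 2 key/value heads, head_dim = 1.
q = np.array([[1.0, 2.0, 3.0, 4.0]])  # columns: q1 q2 q3 q4
k = np.array([[10.0, 20.0]])          # columns: k1 k2
v = np.array([[100.0, 200.0]])        # columns: v1 v2

fuse = fuse_param_func()
fused = fuse([q, k, v], is_qkv=True, num_heads=4, num_key_value_heads=2)
print(fused)  # [[  1.   2.  10. 100.   3.   4.  20. 200.]] -> [q1, q2, k1, v1, q3, q4, k2, v2]

# convert_fast_ffn de-interleaves alternating columns into two contiguous halves.
interleaved = np.array([[1.0, 10.0, 2.0, 20.0]])    # [a1, b1, a2, b2]
print(fuse([interleaved], convert_fast_ffn=True))   # [[ 1.  2. 10. 20.]] -> [a1, a2, b1, b2]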


def split_param_func():
def fn(fused_param, split_nums=2, is_qkv=False, num_heads=None, num_key_value_heads=None):
"""TODO 参数变换形式、支持类型
排列转换适配兼容说明
DrownFish19 marked this conversation as resolved.
Show resolved Hide resolved

Args:
fused_param (_type_): _description_
split_nums (int, optional): _description_. Defaults to 2.
is_qkv (bool, optional): _description_. Defaults to False.
num_heads (_type_, optional): _description_. Defaults to None.
num_key_value_heads (_type_, optional): _description_. Defaults to None.

Returns:
_type_: _description_
"""
concat_fn = np.concatenate
split_fn = np.split
if isinstance(fused_param, paddle.Tensor):
concat_fn = paddle.concat
split_fn = paddle.split

if is_qkv:
# fuse_attention_qkv
# fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
# or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
# after split
# q => [q1,q2,q3,q4]
# k => [k1,k2,k3,k4] or [k1,k2] for GQA
# v => [v1,v2,v3,v4] or [v1,v2] for GQA

assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}"
assert (
num_key_value_heads
), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
num_query_groups = num_heads // num_key_value_heads
q_list, k_list, v_list = [], [], []
split_heads = split_fn(fused_param, num_heads + 2 * num_key_value_heads, axis=-1)
for i in range(num_key_value_heads):
q_list += split_heads[i * (num_query_groups + 2) : (i + 1) * (num_query_groups + 2) - 2]
k_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 2])
v_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 1])
return concat_fn(q_list, axis=-1), concat_fn(k_list, axis=-1), concat_fn(v_list, axis=-1)
else:
# fuse_attention_ffn
return split_fn(fused_param, split_nums, axis=-1)

return fn


def split_or_fuse_func(is_fuse=True):
return fuse_param_func() if is_fuse else split_param_func()
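
A matching round-trip sketch: split_param_func recovers Q, K, and V from the interleaved layout above (same toy GQA setup, illustrative values only), and the non-QKV branch simply cuts a fused tensor into split_nums equal parts.

import numpy as np

from paddlenlp.transformers.conversion_utils import split_param_func

# Fused layout from the previous sketch: [q1, q2, k1, v1, q3, q4, k2, v2]
fused = np.array([[1.0, 2.0, 10.0, 100.0, 3.0, 4.0, 20.0, 200.0]])

split = split_param_func()
q, k, v = split(fused, is_qkv=True, num_heads=4, num_key_value_heads=2)
print(q)  # [[1. 2. 3. 4.]]
print(k)  # [[10. 20.]]
print(v)  # [[100. 200.]]

# fuse_attention_ffn case: split a fused gate_up projection back into two halves.
gate, up = split(np.array([[1.0, 2.0, 3.0, 4.0]]), split_nums=2)
print(gate, up)  # [[1. 2.]] [[3. 4.]]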


def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
def fn(
x,
@@ -1100,6 +1198,7 @@
weight_file (str | None): the weight file path of `model_state.pdparams` file
config (PretrainedConfig): the PretrainedConfig instance of model
"""

name_action_mappings = cls._get_tensor_parallel_mappings(config)
if state_dict is None:
with device_guard("cpu"):
@@ -1201,6 +1300,126 @@

return state_keys_map

@classmethod
def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
loaded_keys = state_dict.keys()
# collect and convert fuse/split action
fused_and_split_keys = []
fuse_actions, resume_keys = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
for keys, action in fuse_actions.items():
origin_states = [state_dict.pop(key) for key in keys[:-1]]
state_dict[keys[-1]] = action(origin_states)
fused_and_split_keys.append(keys[-1])
logger.info(f"Fusing parameter: {keys[:-1]} into {keys[-1]}")

split_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False)
for keys, action in split_actions.items():
origin_state = state_dict.pop(keys[-1])
split_states = action(origin_state)
for key_idx, key in enumerate(keys[:-1]):
state_dict[key] = split_states[key_idx]
fused_and_split_keys.append(key)
logger.info(f"Splitting parameter: {keys[-1]} into {keys[:-1]}")

if tp_actions is not None:
for key in fused_and_split_keys:
for name in tp_actions.keys():
if key.endswith(name):
with device_guard():
state_dict[key] = paddle.Tensor(tp_actions[name](state_dict.pop(key)), zero_copy=True)
break

# When a sharded checkpoint splits the weights as below, some of them must be
# carried over (resumed) for the next shard file:
# shard-001-file: q_weight, k_weight
# shard-002-file: v_weight
resume_state_dict = {k: state_dict[k] for k in resume_keys if k in state_dict}
return state_dict, resume_state_dict
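
A hedged sketch of how convert_fuse_and_split copes with weights spread across shard files: keys that cannot be fused yet are handed back in resume_state_dict for the next shard. The two toy "shards", the llama.-prefixed key names, and the tiny LlamaConfig values are illustrative assumptions, not a recommended setup.

import numpy as np

from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=8, num_attention_heads=4, num_key_value_heads=2,
    num_hidden_layers=1, fuse_attention_qkv=True,
)

# q/k arrive in the first shard, v only in the second.
shard_1 = {
    "llama.layers.0.self_attn.q_proj.weight": np.ones((8, 8), dtype="float32"),
    "llama.layers.0.self_attn.k_proj.weight": np.ones((8, 4), dtype="float32"),
}
shard_2 = {
    "llama.layers.0.self_attn.v_proj.weight": np.ones((8, 4), dtype="float32"),
}

state_dict, resume = LlamaForCausalLM.convert_fuse_and_split(config, shard_1)
# v_proj is missing, so the loaded q/k tensors come back in `resume` for the next shard.
shard_2.update(resume)
state_dict, resume = LlamaForCausalLM.convert_fuse_and_split(config, shard_2)
print(sorted(state_dict))  # ['llama.layers.0.self_attn.qkv_proj.weight']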

@classmethod
def get_fuse_or_split_param_convert_actions(
cls,
config: PretrainedConfig,
loaded_state_dict_keys,
is_fuse=True,
ignore_error=False,
):
name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse=is_fuse
)
for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

# filter name_action_mappings with corresponding weights
# fusing: verify all of the keys in name_action_mappings are in loaded_state_dict_keys
# splitting: verify the last key in name_action_mappings is in loaded_state_dict_keys
filter_name_action = {}
resume_keys = []
if is_fuse:
for k, v in name_action_mappings.items():
cond = True
if not all(item in loaded_state_dict_keys for item in k[:-1]):
# resume keys for next fuse
resume_keys += k[:-1]
cond = False
if cond:
filter_name_action[k] = v
else:
for k, v in name_action_mappings.items():
if k[-1] in loaded_state_dict_keys:
filter_name_action[k] = v

return filter_name_action, resume_keys

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> List[StateDictNameMapping]:
"""get fused parameter mapping of PretrainedModel

Args:
config (PretrainedConfig): the configuration of name-mapping

Raises:
NotImplementedError:

Returns:
List[StateDictNameMapping]: the name-mappings for tensor_parallel
"""
# raise NotImplementedError(
# f"`_get_fuse_or_split_param_mappings` is not implemented for {cls.__name__}`. To implement it, you should "
# f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`"
# )
return {}

@staticmethod
def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True):
state_keys_map = {}

# is_fuse: True -> fuse, False -> split
# True: for keys [x1, x2, x3, x4], [x1, x2, x3] exist in state_keys_real and x4 does not
# False: for keys [x1, x2, x3, x4], [x1, x2, x3] do not exist in state_keys_real and x4 does

for keys in state_keys_base:
prefix = ""
if is_fuse:
for x in state_keys_real:
for base_key in keys[:-1]:
if x.endswith(base_key):
prefix = x.replace(base_key, "")
break
if prefix != "":
break
else:
base_key = keys[-1]
for x in state_keys_real:
if x.endswith(base_key):
prefix = x.replace(base_key, "")
break

new_keys = tuple([prefix + key for key in keys])
state_keys_map[keys] = new_keys

return state_keys_map
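
A small sketch of the prefix resolution above, assuming the static method is reachable via ConversionMixin and using toy key names with an illustrative llama. prefix.

from paddlenlp.transformers.conversion_utils import ConversionMixin

base_keys = (
    (
        "layers.0.self_attn.q_proj.weight",
        "layers.0.self_attn.k_proj.weight",
        "layers.0.self_attn.v_proj.weight",
        "layers.0.self_attn.qkv_proj.weight",  # fused key, not expected in the checkpoint
    ),
)
real_keys = [
    "llama.layers.0.self_attn.q_proj.weight",
    "llama.layers.0.self_attn.k_proj.weight",
    "llama.layers.0.self_attn.v_proj.weight",
]

mapping = ConversionMixin._resolve_prefix_keys_for_fuse_and_split(base_keys, real_keys, is_fuse=True)
# Every name in the tuple gains the detected "llama." prefix, including the fused key.
print(mapping[base_keys[0]][-1])  # llama.layers.0.self_attn.qkv_proj.weight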


class Converter(ConversionMixin, LogitComparer):
"""some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down.
49 changes: 49 additions & 0 deletions paddlenlp/transformers/gpt/modeling.py
@@ -844,6 +844,55 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: GPTConfig, is_fuse=False):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# last key is fused key, other keys are to be fused.
fuse_qkv_keys = (
"decoder.layers.0.self_attn.q_proj.weight",
"decoder.layers.0.self_attn.k_proj.weight",
"decoder.layers.0.self_attn.v_proj.weight",
"decoder.layers.0.self_attn.qkv_proj.weight",
)
fuse_qkv_bias_keys = (
"decoder.layers.0.self_attn.q_proj.bias",
"decoder.layers.0.self_attn.k_proj.bias",
"decoder.layers.0.self_attn.v_proj.bias",
"decoder.layers.0.self_attn.qkv_proj.bias",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
weight_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[weight_keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
bias_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_bias_keys])
final_actions[bias_keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
else:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
weight_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[weight_keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
bias_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_bias_keys])
final_actions[bias_keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
return final_actions

@classmethod
def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
1 change: 1 addition & 0 deletions paddlenlp/transformers/gpt/modeling_pp.py
@@ -161,6 +161,7 @@ class GPTForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
config_class = GPTConfig

_get_tensor_parallel_mappings = GPTPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = GPTPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = GPTPretrainedModel._init_weights

pretrained_init_configuration = GPTPretrainedModel.pretrained_init_configuration
50 changes: 50 additions & 0 deletions paddlenlp/transformers/llama/modeling.py
@@ -1302,6 +1302,56 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# last key is fused key, other keys are to be fused.
fuse_qkv_keys = (
"layers.0.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight",
"layers.0.self_attn.v_proj.weight",
"layers.0.self_attn.qkv_proj.weight",
)

fuse_gate_up_keys = (
"layers.0.mlp.gate_proj.weight",
"layers.0.mlp.up_proj.weight",
"layers.0.mlp.gate_up_fused_proj.weight",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = fn
else:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
final_actions[keys] = partial(
fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = partial(fn, split_nums=2)
return final_actions
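
As a usage sketch, the returned fuse-direction mapping can be inspected directly; the tiny config values below are illustrative assumptions only.

from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=8, num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=2,
    fuse_attention_qkv=True, fuse_attention_ffn=True,
)
actions = LlamaForCausalLM._get_fuse_or_split_param_mappings(config, is_fuse=True)
for keys in actions:
    print(keys[:-1], "->", keys[-1])
# layers.{i}.self_attn.{q,k,v}_proj.weight -> layers.{i}.self_attn.qkv_proj.weight
# layers.{i}.mlp.{gate,up}_proj.weight     -> layers.{i}.mlp.gate_up_fused_proj.weight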

def _init_weights(self, layer):
"""Initialization hook"""
if self.config.tensor_parallel_degree > 1:
1 change: 1 addition & 0 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -210,6 +210,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
config_class = LlamaConfig

_get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = LlamaPretrainedModel._init_weights
_keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected
