Redefine fuse and split functions
DrownFish19 committed Apr 1, 2024
1 parent 28ed30f commit ef1fb18
Showing 3 changed files with 144 additions and 259 deletions.
183 changes: 109 additions & 74 deletions paddlenlp/transformers/conversion_utils.py
@@ -17,7 +17,6 @@
import inspect
import json
import os
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import (
@@ -490,22 +489,22 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads):
return naive_merged_qkv_to_tensor_parallel_qkv(weight)


def merged_as_tensor_parallel_qkv(state_dict, q_name, k_name, v_name, num_hidden_layers):
q = state_dict[q_name]
k = state_dict[k_name]
v = state_dict[v_name]
def fuse_param_func():
def fn(fuse_params: List[np.array]):
return np.concatenate(fuse_params, axis=-1)

naive_merged_qkv = np.concatenate((q, k, v), axis=-1)
return fn

return naive_merged_qkv_to_tensor_parallel_qkv(naive_merged_qkv, num_hidden_layers)

def split_param_func(split_nums):
def fn(fused_param):
return np.split(fused_param, split_nums, axis=-1)

def merge_as_naive_merged_qkv():
pass
return fn


def merge_as_splited_qkv():
pass
def split_or_fuse_func(is_fuse=True):
return fuse_param_func if is_fuse else split_param_func


def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
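
The two new helpers above are small factories that return NumPy-based callables: fuse_param_func() yields a function that concatenates a list of weights along the last axis, and split_param_func(split_nums) yields one that splits a fused weight back into split_nums pieces. A minimal sketch of how they might be exercised; the toy shapes are illustrative, not taken from the commit:

import numpy as np

from paddlenlp.transformers.conversion_utils import fuse_param_func, split_param_func

# toy projection weights of shape [hidden_size, head_dim]
q = np.zeros((8, 4))
k = np.ones((8, 4))
v = np.full((8, 4), 2.0)

fuse = fuse_param_func()
qkv = fuse([q, k, v])          # concatenated along the last axis -> shape (8, 12)

split = split_param_func(3)
q2, k2, v2 = split(qkv)        # three (8, 4) arrays again
assert np.allclose(k, k2)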
@@ -1101,19 +1100,10 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMappi

@classmethod
def get_tensor_parallel_convert_actions(
cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False, ignore_params=[]
cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False
):
name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)

# avoid acting on fused parameters (qkv/gate-up); they are not consistent between config and loaded_state_dict_keys
name_map_list = cls._get_name_mappings(config)
for key in ignore_params:
for name_map in name_map_list:
if name_map.target_name == key:
name_action_mappings.pop(name_map.source_name.split("model.")[-1], None)

state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error)

for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)
return name_action_mappings
@@ -1129,66 +1119,27 @@ def convert_tensor_parallel(
config (PretrainedConfig): the PretrainedConfig instance of model
"""

def _apply_tp_action(name_action_mappings):
state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)

for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

for name, action in name_action_mappings.items():
if name not in state_dict:
if not ignore_error:
logger.warning(f"Key <{name}> not in the model state weight file.")
continue
tensor = state_dict.pop(name)
new_tensor = action(tensor)
with device_guard("cpu"):
state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)

name_action_mappings = cls._get_tensor_parallel_mappings(config)
if state_dict is None:
with device_guard("cpu"):
state_dict = paddle.load(weight_file, return_numpy=False)
logger.info("Starting to convert orignal state_dict to tensor parallel state_dict.")

from paddlenlp.transformers.model_utils import select_fuse_parameter
state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)

do_fuse_parameter_list, do_separate_parameter_list = select_fuse_parameter(cls, state_dict.keys(), config)
if "attention_qkv_proj" in do_fuse_parameter_list:
state_dict, fuse_success = cls.fuse_attention_parameters(
state_dict, ["attention_qkv_proj"], config
) # design: q, k, v => qkv

name_action_mappings = cls._get_tensor_parallel_mappings(config)
for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

# avoid acting on fused parameters (qkv/gate-up); they are not consistent between config and loaded_state_dict_keys
# pop qkv tp actions and apply the rest actions
if "attention_qkv_proj" in do_fuse_parameter_list:

name_map_list = [
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.q_proj.weight"),
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.k_proj.weight"),
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.v_proj.weight"),
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.qkv_proj.weight"),
]
tp_action_keys = list(name_action_mappings.keys())
poped_param_names = []
for key in tp_action_keys:
for name_map in name_map_list:
if re.sub(r"\d+", "0", key) == name_map(0):
name_action_mappings.pop(key, None)
poped_param_names.append(key)

_apply_tp_action(name_action_mappings)

# tail processing qkv parameters
if "attention_qkv_proj" in do_fuse_parameter_list:
name_action_mappings_fuse = cls._get_tensor_parallel_mappings(config)
tp_action_fuse_keys = list(name_action_mappings_fuse.keys())
for key in tp_action_fuse_keys:
if key not in poped_param_names:
name_action_mappings_fuse.pop(key, None)

_apply_tp_action(name_action_mappings_fuse)
for name, action in name_action_mappings.items():
if name not in state_dict:
if not ignore_error:
logger.warning(f"Key <{name}> not in the model state weight file.")
continue
tensor = state_dict.pop(name)
new_tensor = action(tensor)
with device_guard("cpu"):
state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
state_dict = cls.convert_fuse_and_split(config, state_dict, name_action_mappings)

return state_dict
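
For context, each value in name_action_mappings is a callable produced by _get_tensor_parallel_mappings that maps a full weight to this rank's shard (when splitting) or back to the full weight (when merging). A rough, illustrative sketch of what a column-wise split action does; the helper name and shapes below are assumptions, not code from this commit:

import numpy as np

def make_column_split_action(tensor_parallel_degree, tensor_parallel_rank):
    # split the last axis into `tensor_parallel_degree` shards and keep this rank's shard
    def action(weight):
        shards = np.split(weight, tensor_parallel_degree, axis=-1)
        return shards[tensor_parallel_rank]
    return action

weight = np.arange(32).reshape(4, 8)
action = make_column_split_action(tensor_parallel_degree=2, tensor_parallel_rank=1)
shard = action(weight)  # shape (4, 4): the second half of the columns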

@@ -1270,6 +1221,90 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):

return state_keys_map

@classmethod
def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
loaded_keys = list(state_dict.keys())  # snapshot, so keys fused below do not also trigger the split branch
# collect and convert fuse/split action
fused_and_split_keys = []
fuse_actions = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
for keys, action in fuse_actions.items():
origin_states = [state_dict[key] for key in keys[:-1]]  # all but the last key are the source parameters
state_dict[keys[-1]] = action(origin_states)
fused_and_split_keys.append(keys[-1])

split_actions = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False)
for keys, action in split_actions.items():
origin_state = state_dict[keys[-1]]
split_states = action(origin_state)
for key_idx, key in enumerate(keys[:-1]):
state_dict[key] = split_states[key_idx]
fused_and_split_keys.append(key)

if tp_actions is not None:
for key in fused_and_split_keys:
if key in tp_actions:
state_dict[key] = tp_actions[key](state_dict.pop(key))
return state_dict
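
Putting the fuse direction together: the actions are keyed by tuples in which every entry but the last names a source parameter and the last entry names the fused target. A self-contained toy illustrating that contract; the parameter names are illustrative, and real mappings come from _get_fuse_or_split_param_mappings:

import numpy as np

from paddlenlp.transformers.conversion_utils import fuse_param_func

state_dict = {
    "llama.layers.0.self_attn.q_proj.weight": np.zeros((8, 8)),
    "llama.layers.0.self_attn.k_proj.weight": np.zeros((8, 8)),
    "llama.layers.0.self_attn.v_proj.weight": np.zeros((8, 8)),
}
fuse_actions = {
    (
        "llama.layers.0.self_attn.q_proj.weight",
        "llama.layers.0.self_attn.k_proj.weight",
        "llama.layers.0.self_attn.v_proj.weight",
        "llama.layers.0.self_attn.qkv_proj.weight",
    ): fuse_param_func()
}
for keys, action in fuse_actions.items():
    state_dict[keys[-1]] = action([state_dict[key] for key in keys[:-1]])

# state_dict now also holds "llama.layers.0.self_attn.qkv_proj.weight" with shape (8, 24)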

@classmethod
def get_fuse_or_split_param_convert_actions(
cls,
config: PretrainedConfig,
loaded_state_dict_keys,
is_fuse=True,
ignore_error=False,
):
name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse=is_fuse
)
for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

filter_name_action = {}
for k, v in name_action_mappings.items():
if is_fuse:
cond = all(item in loaded_state_dict_keys for item in k[:-1])
else:
cond = k[-1] in loaded_state_dict_keys

if cond:
filter_name_action[k] = v

return filter_name_action

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> List[StateDictNameMapping]:
"""get fuse or split parameter mappings of PretrainedModel
Args:
config (PretrainedConfig): the configuration of name-mapping
is_fuse (bool): return fuse mappings when True, split mappings when False
Raises:
NotImplementedError:
Returns:
List[StateDictNameMapping]: the name-mappings for parameter fusing and splitting
"""
raise NotImplementedError(
f"`_get_fuse_or_split_param_mappings` is not implemented for `{cls.__name__}`. To implement it, you should "
f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`"
)

@staticmethod
def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True):
state_keys_map = {}

for keys in state_keys_base:
base_key = keys[0] if is_fuse else keys[-1]
prefix = ""
for x in state_keys_real:
if x.endswith(base_key):
prefix = x.replace(base_key, "")  # keep whatever prefix precedes the base key
break
new_keys = tuple(prefix + key for key in keys)

state_keys_map[keys] = new_keys

return state_keys_map
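
For reference, a self-contained walk-through of the prefix matching performed above, assuming the prefix is taken as the part of the checkpoint key that precedes the base key: the mapping keys carry no model prefix, so the resolver finds a real checkpoint key ending with the base key and re-attaches its prefix to every key in the tuple. The key names below are illustrative.

mapping_keys = (
    "layers.0.self_attn.q_proj.weight",
    "layers.0.self_attn.k_proj.weight",
    "layers.0.self_attn.v_proj.weight",
    "layers.0.self_attn.qkv_proj.weight",
)
checkpoint_keys = [
    "llama.layers.0.self_attn.q_proj.weight",
    "llama.layers.0.self_attn.k_proj.weight",
    "llama.layers.0.self_attn.v_proj.weight",
]

base_key = mapping_keys[0]  # is_fuse=True matches on the first (source) key
prefix = ""
for real_key in checkpoint_keys:
    if real_key.endswith(base_key):
        prefix = real_key.replace(base_key, "")  # "llama."
        break
resolved = tuple(prefix + key for key in mapping_keys)
# resolved[-1] == "llama.layers.0.self_attn.qkv_proj.weight"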


class Converter(ConversionMixin, LogitComparer):
"""some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down.
52 changes: 28 additions & 24 deletions paddlenlp/transformers/llama/modeling.py
@@ -16,7 +16,6 @@
from __future__ import annotations

import math
import re
import warnings
from functools import partial
from typing import Optional, Tuple
@@ -1279,32 +1278,37 @@ def get_tensor_parallel_split_mappings(num_layers):
return mappings

@classmethod
def _get_fused_param_mappings(cls):
def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import (
merged_as_tensor_parallel_qkv,
)
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

final_actions = {}
if config.fuse_attention_qkv:
# last key is fused key, other keys are to be fused.
base_keys = (
"layers.0.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight",
"layers.0.self_attn.v_proj.weight",
"layers.0.self_attn.qkv_proj.weight",
)

# attention: q,k,v -> qkv, ffn: gate, up -> gate_up
mappings = {
"fuse_action": [merged_as_tensor_parallel_qkv, None],
"split_action": [None, None],
"attn_param_names": {
"qkv_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.qkv_proj.weight"),
"q_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.q_proj.weight"),
"k_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.k_proj.weight"),
"v_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.v_proj.weight"),
},
"ffn_param_names": {
"gate_up_proj": lambda layer_id: re.sub(
r"\d+", str(layer_id), "llama.layers.0.mlp.gate_up_proj.weight"
),
"gate_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.mlp.gate_proj.weight"),
"up_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.mlp.up_proj.weight"),
},
}
for i in range(config.num_hidden_layers):
keys = tuple(key.replace("layers.0.", f"layers.{i}.") for key in base_keys)
# materialize the callable: the fuse factory takes no arguments, the split factory needs the number of output parts
final_actions[keys] = fn() if is_fuse else fn(len(base_keys) - 1)

return mappings
if config.fuse_attention_ffn:
base_keys = (
"llama.layers.0.mlp.gate_proj.weight",
"llama.layers.0.mlp.up_proj.weight",
"llama.layers.0.mlp.gate_up_proj.weight",
)
for i in range(config.num_hidden_layers):
keys = tuple(key.replace("layers.0.", f"layers.{i}.") for key in base_keys)
# materialize the callable: the fuse factory takes no arguments, the split factory needs the number of output parts
final_actions[keys] = fn() if is_fuse else fn(len(base_keys) - 1)

return final_actions

def _init_weights(self, layer):
"""Initialization hook"""
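
As a usage sketch, the per-layer fuse mappings added above can be inspected directly; the config values below are illustrative, and the call assumes the classmethod is exposed on the Llama pretrained-model class as in this diff:

from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig()               # illustrative defaults
config.fuse_attention_qkv = True
config.fuse_attention_ffn = False
config.num_hidden_layers = 2

mappings = LlamaForCausalLM._get_fuse_or_split_param_mappings(config, is_fuse=True)
for keys in mappings:
    print(keys[-1])                  # the fused target name for each layer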
