From f29a7b905f83000627f81e7ec09b9b6129b678e6 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 25 Apr 2024 16:11:39 +0800 Subject: [PATCH] [LLM] Support fuse attention q, k, v weights (#8202) 1. add use-interface & fuse action 1.1. modify 1., code order 2. switch to name_mapping 3. solve tp branch 3.2 follow hui, handel qkv separately 3.3 handle pdparams 3.4 from torch 3.5 abandon low_cpu_mem_usage 3.6 solve shard branch * 3.6.1 solve shard branch after rebase develop * code clean * remove debug comment * Redefine fuse and split functions * Redefine fuse and split functions * comment and fix * update method * update QKV fuse and split * support fuse weights in multi-files * add precision compare * simplify function call * support use_fast_ffn * clean modeling and configuration * add test for gpt and opt * fix tp_actions get * add fast_ffn test * add Qwen2Moe * Revert "add Qwen2Moe" This reverts commit 113b8838a7c53f1d131928c30bf1071dfa583445. * add test for split * update doc * update filter_dict_keys --------- Co-authored-by: Zii --- paddlenlp/transformers/conversion_utils.py | 237 +++++++++++++++++ paddlenlp/transformers/gpt/modeling.py | 43 ++++ paddlenlp/transformers/gpt/modeling_pp.py | 1 + paddlenlp/transformers/llama/modeling.py | 50 ++++ paddlenlp/transformers/llama/modeling_pp.py | 1 + paddlenlp/transformers/model_utils.py | 69 ++++- paddlenlp/transformers/opt/configuration.py | 5 + paddlenlp/transformers/opt/modeling.py | 43 ++++ tests/transformers/test_conversion_common.py | 258 +++++++++++++++++++ 9 files changed, 701 insertions(+), 6 deletions(-) create mode 100644 tests/transformers/test_conversion_common.py diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py index 660e79f6a3e5..a23fb808e4b5 100644 --- a/paddlenlp/transformers/conversion_utils.py +++ b/paddlenlp/transformers/conversion_utils.py @@ -489,6 +489,118 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads): return naive_merged_qkv_to_tensor_parallel_qkv(weight) +def fuse_param_func(): + def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None): + """fuse function for fusing weights + + (1) fuse_attention_qkv + q => [q1,q2,q3,q4] + k => [k1,k2,k3,k4] or [k1,k2] for GQA + v => [v1,v2,v3,v4] or [v1,v2] for GQA + fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4] + or for GQA [q1,q2,k1,v1,q3,q4,k2,v2] + (2) fuse_attention_ffn + directly fuse weights to 1 parts + [gate_weight], [up_weight] => [gate_weight, up_weight] + + Args: + fuse_params (_type_): to be fused weights + is_qkv (bool, optional): for attention qkv weights. Defaults to False. + num_heads (_type_, optional): query heads. Defaults to None. + num_key_value_heads (_type_, optional): key and value heads. Defaults to None. + + Returns: + _type_: fused weights + """ + concat_fn = np.concatenate + split_fn = np.split + if isinstance(fuse_params[0], paddle.Tensor): + concat_fn = paddle.concat + split_fn = paddle.split + + if is_qkv: + # fuse_attention_qkv + assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}" + assert ( + num_key_value_heads + ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}" + assert ( + len(fuse_params) == 3 + ), f"fuse_params length is not equal 3, it should be Q K V list. 
but got length {len(fuse_params)}" + num_query_groups = num_heads // num_key_value_heads + q_list = split_fn(fuse_params[0], num_heads, axis=-1) + k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1) + v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1) + + qkv_pairs = [] + for i in range(num_key_value_heads): + qkv_pairs += q_list[i * num_query_groups : (i + 1) * num_query_groups] + qkv_pairs.append(k_list[i]) + qkv_pairs.append(v_list[i]) + return concat_fn(qkv_pairs, axis=-1) + else: + # fuse_attention_ffn + return concat_fn(fuse_params, axis=-1) + + return fn + + +def split_param_func(): + def fn(fused_param, split_nums=2, is_qkv=False, num_heads=None, num_key_value_heads=None): + """split function for splitting weights + + (1) fuse_attention_qkv + fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4] + or for GQA [q1,q2,k1,v1,q3,q4,k2,v2] + after split + q => [q1,q2,q3,q4] + k => [k1,k2,k3,k4] or [k1,k2] for GQA + v => [v1,v2,v3,v4] or [v1,v2] for GQA + (2) fuse_attention_ffn + directly split weight to 2 parts + [gate_weight, up_weight] => [gate_weight], [up_weight] + + Args: + fused_param (_type_): len(fused_param)=1, only one weight to be splitted + split_nums (int, optional): split_nums. Defaults to 2. + is_qkv (bool, optional): for attention qkv weights. Defaults to False. + num_heads (_type_, optional): query heads. Defaults to None. + num_key_value_heads (_type_, optional): key and value heads. Defaults to None. + + Returns: + _type_: splitted weights + """ + concat_fn = np.concatenate + split_fn = np.split + if isinstance(fused_param, paddle.Tensor): + concat_fn = paddle.concat + split_fn = paddle.split + + if is_qkv: + # fuse_attention_qkv + assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}" + assert ( + num_key_value_heads + ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}" + num_query_groups = num_heads // num_key_value_heads + q_list, k_list, v_list = [], [], [] + split_heads = split_fn(fused_param, num_heads + 2 * num_key_value_heads, axis=-1) + for i in range(num_key_value_heads): + q_list += split_heads[i * (num_query_groups + 2) : (i + 1) * (num_query_groups + 2) - 2] + k_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 2]) + v_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 1]) + return concat_fn(q_list, axis=-1), concat_fn(k_list, axis=-1), concat_fn(v_list, axis=-1) + else: + # fuse_attention_ffn + return split_fn(fused_param, split_nums, axis=-1) + + return fn + + +def split_or_fuse_func(is_fuse=True): + return fuse_param_func() if is_fuse else split_param_func() + + def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None): def fn( x, @@ -1100,6 +1212,7 @@ def convert_tensor_parallel( weight_file (str | None): the weight file path of `model_state.pdparams` file config (PretrainedConfig): the PretrainedConfig instance of model """ + name_action_mappings = cls._get_tensor_parallel_mappings(config) if state_dict is None: with device_guard("cpu"): @@ -1201,6 +1314,130 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False): return state_keys_map + @classmethod + def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None): + loaded_keys = state_dict.keys() + # collect and convert fuse/split action + fused_and_split_keys = [] + fuse_actions, resume_keys = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True) + for 
keys, action in fuse_actions.items(): + origin_states = [state_dict.pop(key) for key in keys[:-1]] + state_dict[keys[-1]] = action(origin_states) + fused_and_split_keys.append(keys[-1]) + logger.info(f"Fusing parameter: {keys[:-1]} into {keys[-1]}") + + split_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False) + for keys, action in split_actions.items(): + origin_state = state_dict.pop(keys[-1]) + split_states = action(origin_state) + for key_idx, key in enumerate(keys[:-1]): + state_dict[key] = split_states[key_idx] + fused_and_split_keys.append(key) + logger.info(f"Splitting parameter: {keys[-1]} into {keys[:-1]}") + + if tp_actions is not None: + for key in fused_and_split_keys: + for name in tp_actions.keys(): + if key.endswith(name): + with device_guard(): + state_dict[key] = paddle.Tensor(tp_actions[name](state_dict.pop(key)), zero_copy=True) + break + + # when shard file split the weight as follows, some weights need to be resumed for next shard file + # shard-001-file: q_weight, k_weight + # shard_002-file: v_weight + resume_state_dict = {k: state_dict[k] for k in resume_keys if k in state_dict} + return state_dict, resume_state_dict + + @classmethod + def get_fuse_or_split_param_convert_actions( + cls, + config: PretrainedConfig, + loaded_state_dict_keys, + is_fuse=True, + ignore_error=False, + ): + name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse) + state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split( + name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse + ) + for k, v in state_keys_map.items(): + name_action_mappings[v] = name_action_mappings.pop(k) + + # filter name_action_mappings with corresponding weights + # fusing: verify all of the keys in name_action_mappings are in loaded_state_dict_keys + # splitting: verify the last key in name_action_mappings is in loaded_state_dict_keys + filter_name_action = {} + resume_keys = [] + if is_fuse: + for k, v in name_action_mappings.items(): + cond = True + if not all(item in loaded_state_dict_keys for item in k[:-1]): + # resume keys for next fuse + resume_keys += k[:-1] + cond = False + if cond: + filter_name_action[k] = v + else: + for k, v in name_action_mappings.items(): + if k[-1] in loaded_state_dict_keys: + filter_name_action[k] = v + + return filter_name_action, resume_keys + + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> List[StateDictNameMapping]: + """get fused parameter mapping of PretrainedModel + + Args: + config (PretrainedConfig): the configuration of name-mapping + + Raises: + NotImplementedError: + + Returns: + List[StateDictNameMapping]: the name-mappings for tensor_parallel + """ + # raise NotImplementedError( + # f"`_get_fuse_or_split_param_mappings` is not implemented for {cls.__name__}`. To implement it, you should " + # f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`" + # ) + return {} + + @staticmethod + def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True): + state_keys_map = {} + + # use the tuple (x1,x2,x3,x4) as one key, and the prefix of x1,x2,x3 is used as a new key x4 or + # the last key x4 is used as new keys x1,x2,x3. And, the tuple also could be (a) (x1, x1) -> convert x1 to x1; + # (b) (x1,x2,x3) -> fuse x1 and x2 to x3; (c) (x1,x2,x3,x4) -> fuse x1, x2 and x3 to x4. 
+ + # is_fuse: True -> fuse, False -> split + # True: (x1,x2,x3,x4) -> [x1,x2,x3] are exist in state_keys_real, x4 is not exist in state_keys_real + # False: (x1,x2,x3,x4) -> [x1,x2,x3] are not exist in state_keys_real, x4 is exist in state_keys_real + + for keys in state_keys_base: + prefix = "" + if is_fuse: + for x in state_keys_real: + for base_key in keys[:-1]: + if x.endswith(base_key): + prefix = x.replace(base_key, "") + break + if prefix != "": + break + else: + base_key = keys[-1] + for x in state_keys_real: + if x.endswith(base_key): + prefix = x.replace(base_key, "") + break + + new_keys = tuple([prefix + key for key in keys]) + state_keys_map[keys] = new_keys + + return state_keys_map + class Converter(ConversionMixin, LogitComparer): """some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down. diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py index fcd8962e2282..14569835f078 100644 --- a/paddlenlp/transformers/gpt/modeling.py +++ b/paddlenlp/transformers/gpt/modeling.py @@ -844,6 +844,49 @@ def get_tensor_parallel_split_mappings(num_layers): return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: GPTConfig, is_fuse=False): + # return parameter fuse utils + from paddlenlp.transformers.conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. + fuse_qkv_keys = ( + "decoder.layers.0.self_attn.q_proj.weight", + "decoder.layers.0.self_attn.k_proj.weight", + "decoder.layers.0.self_attn.v_proj.weight", + "decoder.layers.0.self_attn.qkv_proj.weight", + ) + fuse_qkv_bias_keys = ( + "decoder.layers.0.self_attn.q_proj.bias", + "decoder.layers.0.self_attn.k_proj.bias", + "decoder.layers.0.self_attn.v_proj.bias", + "decoder.layers.0.self_attn.qkv_proj.bias", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]: + new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys]) + final_actions[new_keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]: + new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys]) + final_actions[new_keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + return final_actions + @classmethod def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]: mappings: list[StateDictNameMapping] = [] diff --git a/paddlenlp/transformers/gpt/modeling_pp.py b/paddlenlp/transformers/gpt/modeling_pp.py index cd3dce018378..8b350e6556df 100644 --- a/paddlenlp/transformers/gpt/modeling_pp.py +++ b/paddlenlp/transformers/gpt/modeling_pp.py @@ -161,6 +161,7 @@ class GPTForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): config_class = GPTConfig _get_tensor_parallel_mappings = GPTPretrainedModel._get_tensor_parallel_mappings + _get_fuse_or_split_param_mappings = GPTPretrainedModel._get_fuse_or_split_param_mappings _init_weights = GPTPretrainedModel._init_weights pretrained_init_configuration = 
GPTPretrainedModel.pretrained_init_configuration diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 3efbb9de89a1..81eee3f83539 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1293,6 +1293,56 @@ def get_tensor_parallel_split_mappings(num_layers): return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False): + # return parameter fuse utils + from paddlenlp.transformers.conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. + fuse_qkv_keys = ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ) + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.gate_up_fused_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys]) + final_actions[keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + def _init_weights(self, layer): """Initialization hook""" if self.config.tensor_parallel_degree > 1: diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index 73600aa6b420..dd2a91814231 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -210,6 +210,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): config_class = LlamaConfig _get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings + _get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings _init_weights = LlamaPretrainedModel._init_weights _keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index edc1bb2d3439..b4395a39a9dd 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -108,7 +108,6 @@ def unwrap_optimizer(optimizer, optimizer_instances=()): if is_safetensors_available(): - from safetensors import safe_open from safetensors.numpy import load_file as safe_load_file from safetensors.numpy import save_file as safe_save_file @@ -1822,6 +1821,25 @@ def 
_find_mismatched_keys( del state_dict[checkpoint_key] return mismatched_keys + def _fuse_or_split_keys( + state_dict, config, loaded_keys, pre_tensor_parallel_split=False, resume_state_dict=None + ): + if resume_state_dict is not None: + state_dict.update(resume_state_dict) + + before_fuse_keys = list(state_dict.keys()) + if pre_tensor_parallel_split: + tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True) + else: + tp_actions = None + state_dict, resume_state_dict = cls.convert_fuse_and_split(config, state_dict, tp_actions) + after_fuse_keys = list(state_dict.keys()) + + fused_keys = list(set(before_fuse_keys) - set(after_fuse_keys)) + new_keys = list(set(after_fuse_keys) - set(before_fuse_keys)) + + return state_dict, resume_state_dict, fused_keys, new_keys + if state_dict is not None: # DONT Hold tensor parallel here, only hold afer load state dict. # Whole checkpoint @@ -1831,6 +1849,16 @@ def _find_mismatched_keys( state_dict = ft_decoding.get_ft_para_conf().fit_partial_model(model_to_load, state_dict) + # have loaded all state_dict, no resume state_dict + state_dict, _, fused_keys, new_keys = _fuse_or_split_keys( + state_dict, + config, + loaded_keys, + pre_tensor_parallel_split=True if config.tensor_parallel_degree > 1 else False, + ) + missing_keys = list(set(missing_keys) - set(new_keys)) + unexpected_keys = list(set(unexpected_keys) - set(fused_keys)) + mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, @@ -1862,7 +1890,7 @@ def _find_mismatched_keys( error_msgs = [] mismatched_keys = [] - + resume_state_dict = {} if len(resolved_archive_file) > 1: resolved_archive_file = tqdm(resolved_archive_file, desc="Loading checkpoint shards") @@ -1875,13 +1903,42 @@ def _find_mismatched_keys( ): pre_tensor_parallel_split = True assert loaded_keys is not None, "loaded_keys is not None." - tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys) + tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True) # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors + filter_dict_keys = set(expected_keys) + fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True) + split_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False) + for k in list(fuse_actions.keys()): + need_add_except_key = k[-1] in expected_keys + if need_add_except_key: + filter_dict_keys |= set(k[:-1]) + for k in list(split_actions.keys()): + need_add_except_key = False + for item in k[:-1]: + if item in expected_keys: + need_add_except_key = True + break + if need_add_except_key: + filter_dict_keys.add(k[-1]) + + if config.quantization_config.is_weight_quantize(): + filter_dict_keys = None + state_dict = load_state_dict( - shard_file, - tp_actions if pre_tensor_parallel_split else None, - None if config.quantization_config.is_weight_quantize() else set(expected_keys), + shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys + ) + + # convert for fusing or splitting weights + state_dict, resume_state_dict, fused_keys, new_keys = _fuse_or_split_keys( + state_dict, + config, + loaded_keys, + pre_tensor_parallel_split=pre_tensor_parallel_split, + resume_state_dict=resume_state_dict, ) + missing_keys = list(set(missing_keys) - set(new_keys)) + unexpected_keys = list(set(unexpected_keys) - set(fused_keys)) + if config.quantization_config.is_weight_quantize(): state_dict = convert_to_quantize_state_dict( state_dict, diff --git a/paddlenlp/transformers/opt/configuration.py b/paddlenlp/transformers/opt/configuration.py index 866da043198e..3f6f23c1c65d 100644 --- a/paddlenlp/transformers/opt/configuration.py +++ b/paddlenlp/transformers/opt/configuration.py @@ -146,6 +146,8 @@ def __init__( eos_token_id=2, enable_bias: bool = True, mp_degree: int = 1, + fuse_attention_qkv=False, + fuse_attention_ffn=False, **kwargs, ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -165,3 +167,6 @@ def __init__( self.enable_bias = enable_bias self.mp_degree = mp_degree + + self.fuse_attention_qkv = fuse_attention_qkv + self.fuse_attention_ffn = fuse_attention_ffn diff --git a/paddlenlp/transformers/opt/modeling.py b/paddlenlp/transformers/opt/modeling.py index 41cc45482004..c24bd357deaa 100644 --- a/paddlenlp/transformers/opt/modeling.py +++ b/paddlenlp/transformers/opt/modeling.py @@ -649,6 +649,49 @@ def _get_tensor_parallel_mappings(cls, config: OPTConfig, is_split=True): return actions + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: OPTConfig, is_fuse=False): + # return parameter fuse utils + from paddlenlp.transformers.conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
+ fuse_qkv_keys = ( + "decoder.layers.0.self_attn.q_proj.weight", + "decoder.layers.0.self_attn.k_proj.weight", + "decoder.layers.0.self_attn.v_proj.weight", + "decoder.layers.0.self_attn.qkv_proj.weight", + ) + fuse_qkv_bias_keys = ( + "decoder.layers.0.self_attn.q_proj.bias", + "decoder.layers.0.self_attn.k_proj.bias", + "decoder.layers.0.self_attn.v_proj.bias", + "decoder.layers.0.self_attn.qkv_proj.bias", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]: + new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys]) + final_actions[new_keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for keys in [fuse_qkv_keys, fuse_qkv_bias_keys]: + new_keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys]) + final_actions[new_keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + return final_actions + @classmethod def _get_name_mappings(cls, config: OPTConfig) -> list[StateDictNameMapping]: mappings: list[StateDictNameMapping] = [] diff --git a/tests/transformers/test_conversion_common.py b/tests/transformers/test_conversion_common.py new file mode 100644 index 000000000000..989f8665d6a1 --- /dev/null +++ b/tests/transformers/test_conversion_common.py @@ -0,0 +1,258 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
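+
+# These tests round-trip checkpoints between the split layout (separate q/k/v and
+# gate/up projections) and the fused layout (qkv_proj / gate_up_fused_proj) for
+# Llama, GPT and OPT, covering pdparams and safetensors, single-file and sharded,
+# plus a fast-FFN weight-layout conversion defined via a custom
+# _get_fuse_or_split_param_mappings.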
+ + +import copy +import glob +import os +import tempfile +import unittest + +import paddle + +input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + + +def prepare_default_config(config): + config = copy.deepcopy(config) + config.hidden_size = 512 + config.num_layers = 2 + config.num_hidden_layers = 2 + config.num_attention_heads = 16 + config.num_key_value_heads = 16 + config.intermediate_size = config.hidden_size + config.word_embed_proj_dim = 512 + return config + + +def prepare_split_config(config): + config = prepare_default_config(config) + config = copy.deepcopy(config) + config.fuse_attention_qkv = False + config.fuse_attention_ffn = False + return config + + +def prepare_fuse_config(config): + config = prepare_default_config(config) + config = copy.deepcopy(config) + config.fuse_attention_qkv = True + config.fuse_attention_ffn = True + return config + + +def common_test_load(model_class, model_first, config_second, tempdir): + model_first.eval() + with paddle.no_grad(): + first = model_first(input_ids)[0] + + model_second = model_class.from_pretrained(tempdir, config=config_second) + model_second.eval() + with paddle.no_grad(): + second = model_second(input_ids)[0] + + assert paddle.allclose(paddle.mean(first), paddle.mean(second), atol=1e-7) + assert paddle.allclose(first, second, atol=1e-4) + + files = glob.glob(tempdir + "/*") + for f in files: + os.remove(f) + + +def common_test_save_and_load(config_first, config_second, model_class): + model_first = model_class.from_config(config_first) + + with tempfile.TemporaryDirectory() as tempdir: + # test load pdparams: model.pdparams + model_first.save_pretrained(save_dir=tempdir) + common_test_load(model_class, model_first, config_second, tempdir) + + # test load shard pdparams: model-001-0f-008.pdparams + model_first.save_pretrained(tempdir, max_shard_size="5MB") + common_test_load(model_class, model_first, config_second, tempdir) + + # test save safetensors: model.safetensors + model_first.save_pretrained(tempdir, safe_serialization=True) + common_test_load(model_class, model_first, config_second, tempdir) + + # test load shard safetensors: model-001-0f-008.safetensors + model_first.save_pretrained(tempdir, max_shard_size="5MB", safe_serialization=True) + common_test_load(model_class, model_first, config_second, tempdir) + + +def _test_split_to_fuse(config_class, model_class): + config = config_class() + + config_split = prepare_split_config(config) + config_fuse = prepare_fuse_config(config) + + # Test from splitted weights to fused weight + common_test_save_and_load(config_split, config_fuse, model_class) + + +def _test_fuse_to_split(config_class, model_class): + config = config_class() + + config_split = prepare_split_config(config) + config_fuse = prepare_fuse_config(config) + + # Test from fused weight to splitted weights + common_test_save_and_load(config_fuse, config_split, model_class) + + +def _test_fast_ffn(): + from functools import partial + + import paddle + from paddle import nn + + from paddlenlp.transformers import PretrainedModel + from paddlenlp.transformers.configuration_utils import PretrainedConfig + + class TestConfig(PretrainedConfig): + def __init__(self, fast_ffn_state=False, convert_fast_ffn=False): + self.fast_ffn_state = fast_ffn_state + self.convert_fast_ffn = convert_fast_ffn + super().__init__() + + class TestMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.gate_up_fused_proj = 
nn.Linear(self.hidden_size, self.hidden_size * 2, bias_attr=True) + + def forward(self, hidden_state): + hidden_state = self.gate_up_fused_proj(hidden_state) + if self.config.use_fast_ffn: + x, y = paddle.chunk(hidden_state, chunks=2, axis=-1) + else: + x, y = hidden_state[..., ::2], hidden_state[..., 1::2] + + return nn.functional.silu(x) * y + + class TestPretrainedModel(PretrainedModel): + config_class = TestConfig + + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: TestConfig, is_fuse=False): + + # user defined function to get convert param mappings + def convert_fast_ffn_fn(fuse_params, convert_fast_ffn=False): + import numpy as np + + concat_fn = np.concatenate + if isinstance(fuse_params[0], paddle.Tensor): + concat_fn = paddle.concat + + if convert_fast_ffn: + # fast_ffn + first = fuse_params[0][..., ::2] + second = fuse_params[0][..., 1::2] + return concat_fn([first, second], axis=-1) + + fn = convert_fast_ffn_fn + + convert_fast_ffn_keys = ( + "layers.0.gate_up_fused_proj.weight", + "layers.0.gate_up_fused_proj.weight", + ) + convert_fast_ffn_bias_keys = ( + "layers.0.gate_up_fused_proj.bias", + "layers.0.gate_up_fused_proj.bias", + ) + fast_ffn_state = getattr(config, "fast_ffn_state", False) + convert_fast_ffn = getattr(config, "convert_fast_ffn", False) + convert_fast_ffn &= not fast_ffn_state + + final_actions = {} + if is_fuse: + # for_get_fuse_or_split_param_mappings, is_fuse have two conditions, true and false, + # to fit partial fuse or split conditions, is_fuse will called twice(True and False). + # thus, for this func, we only use one condition. + + # use_fast_ffn only in one condition + # convert when use_fast_ffn is False + if convert_fast_ffn: + for i in range(config.num_hidden_layers): + for keys in [convert_fast_ffn_keys, convert_fast_ffn_bias_keys]: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in keys]) + final_actions[keys] = partial(fn, convert_fast_ffn=convert_fast_ffn) + return final_actions + + def _init_weights(self, layer): + if isinstance(layer, (nn.Linear, nn.Embedding)): + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value(paddle.tensor.normal(mean=0.0, std=1.0, shape=layer.weight.shape)) + if hasattr(layer, "bias") and isinstance(layer.bias, paddle.Tensor): + layer.bias.set_value(paddle.tensor.normal(mean=0.0, std=1.0, shape=layer.bias.shape)) + + class TestModel(TestPretrainedModel): + def __init__(self, config): + super().__init__(config) + self.layers = nn.LayerList([TestMLP(config=config) for i in range(config.num_hidden_layers)]) + + def forward(self, hidden_state): + for idx, (decoder_layer) in enumerate(self.layers): + hidden_state = decoder_layer(hidden_state) + return hidden_state + + class TestForCausalLM(TestPretrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embedding_layer = nn.Embedding(65535, self.config.hidden_size) + self.test = TestModel(config=config) + + def forward(self, input_ids): + hidden_state = self.embedding_layer(input_ids) + return self.test(hidden_state) + + config = TestConfig() + config = prepare_default_config(config) + config_no_fast_ffn = copy.deepcopy(config) + config_fast_ffn = copy.deepcopy(config) + + config_no_fast_ffn.use_fast_ffn = False + + config_fast_ffn.use_fast_ffn = True + config_fast_ffn.fast_ffn_state = False + config_fast_ffn.convert_fast_ffn = True + + common_test_save_and_load(config_no_fast_ffn, config_fast_ffn, TestForCausalLM) + + +from paddlenlp.transformers import ( + GPTConfig, + 
GPTForCausalLM, + LlamaConfig, + LlamaForCausalLM, + OPTConfig, + OPTForCausalLM, +) + + +class TestFuseOrSplit(unittest.TestCase): + def test_model_split_to_fuse(self): + _test_split_to_fuse(LlamaConfig, LlamaForCausalLM) + _test_split_to_fuse(GPTConfig, GPTForCausalLM) + _test_split_to_fuse(OPTConfig, OPTForCausalLM) + + def test_model_fuse_to_split(self): + _test_fuse_to_split(LlamaConfig, LlamaForCausalLM) + _test_fuse_to_split(GPTConfig, GPTForCausalLM) + _test_fuse_to_split(OPTConfig, OPTForCausalLM) + + def test_model_convert_fast_ffn(self): + _test_fast_ffn()
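
A minimal sketch of how the new helpers added to conversion_utils.py are expected to round-trip GQA weights (illustrative head counts and shapes, assuming this patch is applied):

    import numpy as np
    from paddlenlp.transformers.conversion_utils import fuse_param_func, split_param_func

    hidden_size, num_heads, num_key_value_heads, head_dim = 8, 4, 2, 4

    # separate projections, as stored by a non-fused checkpoint
    q = np.random.rand(hidden_size, num_heads * head_dim)
    k = np.random.rand(hidden_size, num_key_value_heads * head_dim)
    v = np.random.rand(hidden_size, num_key_value_heads * head_dim)

    fuse = fuse_param_func()
    split = split_param_func()

    # fused layout groups query heads with their shared key/value head:
    # [q1, q2, k1, v1, q3, q4, k2, v2] for num_heads=4, num_key_value_heads=2
    qkv = fuse([q, k, v], is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads)
    assert qkv.shape == (hidden_size, (num_heads + 2 * num_key_value_heads) * head_dim)

    # splitting the fused weight recovers the original projections
    q2, k2, v2 = split(qkv, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads)
    assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)

In user code the same conversion is triggered implicitly: saving a Llama/GPT/OPT checkpoint with fuse_attention_qkv=False and reloading it with fuse_attention_qkv=True (or the reverse) goes through convert_fuse_and_split inside from_pretrained, which is exactly what tests/transformers/test_conversion_common.py exercises.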