diff --git a/llm/data.py b/llm/data.py index 5a13ddd11db7..8000bd598455 100644 --- a/llm/data.py +++ b/llm/data.py @@ -43,11 +43,21 @@ def get_convert_example(model): if base_model_prefix == "chatglm": return convert_example_chatglm - elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]: + elif base_model_prefix in [ + "chatglm_v2", + "llama", + "bloom", + "opt", + "qwen", + "mixtral", + "gemma", + "qwen2", + "qwen2_moe", + ]: return convert_example_common else: raise ValueError( - f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma" + f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe", ) diff --git a/llm/qwen/README.md b/llm/qwen/README.md index d923a4ac3677..22ac37c19e17 100644 --- a/llm/qwen/README.md +++ b/llm/qwen/README.md @@ -5,15 +5,57 @@ [通义千问(Qwen)](https://arxiv.org/abs/2205.01068) 是阿里云研发的通义千问大模型系列的模型, 有 70 亿和 140 亿两个规模。Qwen是基于Transformer的大语言模型, 在超大规模的预训练数据上进行训练得到。预训练数据类型多样,覆盖广泛,包括大量网络文本、专业书籍、代码等。 **支持模型权重:** -| Model | -|-------------------| -| qwen/qwen-7b | -| qwen/qwen-7b-chat | -| qwen/qwen-14b | -| qwen/qwen-14b-chat| -| qwen/qwen-72b | -| qwen/qwen-72b-chat| -| qwen/qwen1.5-moe-a2.7b| +| Model | +|--------------------| +| qwen/qwen-7b | +| qwen/qwen-7b-chat | +| qwen/qwen-14b | +| qwen/qwen-14b-chat | +| qwen/qwen-72b | +| qwen/qwen-72b-chat | + + + +[通义千问(Qwen1.5)](https://qwenlm.github.io/blog/qwen1.5/) 是阿里云研发的通义千问系列模型升级版。Qwen1.5包括0.5B、1.8B、4B、7B、14B、32B、72B、110B和MoE共计9个不同规模的Base和Chat模型。 + +**支持模型权重:** +| Model (qwen-1.5) | +|-----------------------------| +| Qwen/Qwen1.5-0.5B | +| Qwen/Qwen1.5-0.5B-Chat | +| Qwen/Qwen1.5-1.8B | +| Qwen/Qwen1.5-1.8B-Chat | +| Qwen/Qwen1.5-4B | +| Qwen/Qwen1.5-4B-Chat | +| Qwen/Qwen1.5-7B | +| Qwen/Qwen1.5-7B-Chat | +| Qwen/Qwen1.5-14B | +| Qwen/Qwen1.5-14B-Chat | +| Qwen/Qwen1.5-32B | +| Qwen/Qwen1.5-32B-Chat | +| Qwen/Qwen1.5-72B | +| Qwen/Qwen1.5-72B-Chat | +| Qwen/Qwen1.5-110B | +| Qwen/Qwen1.5-110B-Chat | +| Qwen/Qwen1.5-MoE-A2.7B | +| Qwen/Qwen1.5-MoE-A2.7B-Chat | + + +[通义千问(Qwen2)](https://qwenlm.github.io/blog/qwen2/) 是阿里云研发的通义千问系列模型升级版。Qwen2包括0.5B、1.5B、7B、72B和MoE共计5个不同规模的Base和Chat模型。 +**支持模型权重:** +| Model (qwen2) | +|------------------------------| +| Qwen/Qwen2-0.5B | +| Qwen/Qwen2-0.5B-Instruct | +| Qwen/Qwen2-1.5B | +| Qwen/Qwen2-1.5B-Instruct | +| Qwen/Qwen2-7B | +| Qwen/Qwen2-7B-Instruct | +| Qwen/Qwen2-72B | +| Qwen/Qwen2-72B-Instruct | +| Qwen/Qwen2-57B-A14B | +| Qwen/Qwen2-57B-A14B-Instruct | + ## 2. 模型精调 请参考[LLM全流程工具介绍](../README.md) diff --git a/llm/qwen/lora_argument_qwen2moe.json b/llm/qwen/lora_argument_qwen2moe.json new file mode 100644 index 000000000000..0344e3885ba0 --- /dev/null +++ b/llm/qwen/lora_argument_qwen2moe.json @@ -0,0 +1,32 @@ +{ + "model_name_or_path": "Qwen/Qwen1.5-MoE-A2.7B", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/qwen2moe_lora_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 32768, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 8, + "pipeline_parallel_degree": 1, + "lora": true, + "zero_padding": false, + "use_flash_attention": false + } \ No newline at end of file diff --git a/llm/qwen/sft_argument_qwen2moe.json b/llm/qwen/sft_argument_qwen2moe.json new file mode 100644 index 000000000000..75d3a93500f5 --- /dev/null +++ b/llm/qwen/sft_argument_qwen2moe.json @@ -0,0 +1,30 @@ +{ + "model_name_or_path": "Qwen/Qwen1.5-MoE-A2.7B", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/qwen2moe_sft_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-05, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 32768, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 8, + "sharding": "stage2", + "pipeline_parallel_degree": 1 +} \ No newline at end of file diff --git a/llm/utils.py b/llm/utils.py index 2c177fc041f4..4bca491f4806 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -41,7 +41,6 @@ def compute_metrics(eval_preds): - flattened_preds = np.array(eval_preds.predictions).flatten() flattened_labels = np.array(eval_preds.label_ids).flatten() filtered_preds = flattened_preds[flattened_labels != -100] @@ -157,10 +156,22 @@ def get_lora_target_modules(model): ".*k_proj.*", ".*v_proj.*", ".*o_proj.*", + # ".*gate.*", # TODO(DrownFish19): Does the gate weight require training? ".*w1.*", ".*w2.*", ".*w3.*", ] + elif model.base_model_prefix == "qwen2_moe": + target_modules = [ + ".*q_proj.*", + ".*k_proj.*", + ".*v_proj.*", + ".*o_proj.*", + # ".*gate.*", # TODO(DrownFish19): Does the gate weight require training? + ".*gate_proj.*", + ".*up_proj.*", + ".*down_proj.*", + ] else: raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.") return target_modules diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index bd743c59f934..324ec2a09e4f 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -294,6 +294,8 @@ from .deberta_v2.modeling import * from .deberta_v2.tokenizer import * from .deberta_v2.configuration import * +from .qwen2 import * +from .qwen2_moe import * # For faster tokenizer from ..utils.import_utils import is_fast_tokenizer_available diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index ed23b6758ca2..f22126f76c8a 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -118,6 +118,8 @@ ("Bloom", "bloom"), ("QWen", "qwen"), ("Mixtral", "mixtral"), + ("Qwen2", "qwen2"), + ("Qwen2Moe", "qwen2_moe"), ("Gemma", "gemma"), ] ) @@ -215,15 +217,20 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file else: init_class = config.pop("init_class", None) init_class = init_class[:-5] if init_class is not None and init_class.endswith("Model") else init_class + + # Sort the MAPPING_NAMES to reorder the model class names with longest-first rule + # thus the names with same prefix can be correctly inferred + # such as QWen and QWen2MOE, QWen2MOE is the longest prefix of QWen2MOEModel model_name = None + SORTED_MAPPING_NAMES = dict(sorted(MAPPING_NAMES.items(), key=lambda x: len(x[0]), reverse=True)) if init_class: - for model_flag, name in MAPPING_NAMES.items(): + for model_flag, name in SORTED_MAPPING_NAMES.items(): if model_flag in init_class: model_name = model_flag + "Model" break else: # From pretrained_model_name_or_path - for model_flag, name in MAPPING_NAMES.items(): + for model_flag, name in SORTED_MAPPING_NAMES.items(): if name in pretrained_model_name_or_path.lower(): model_name = model_flag + "Model" break diff --git a/paddlenlp/transformers/qwen2/__init__.py b/paddlenlp/transformers/qwen2/__init__.py new file mode 100644 index 000000000000..6e9485ea762a --- /dev/null +++ b/paddlenlp/transformers/qwen2/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import * +from .modeling import * +from .tokenizer import * diff --git a/paddlenlp/transformers/qwen2/configuration.py b/paddlenlp/transformers/qwen2/configuration.py new file mode 100644 index 000000000000..71eb7ee2f2d2 --- /dev/null +++ b/paddlenlp/transformers/qwen2/configuration.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from ..configuration_utils import PretrainedConfig + +__all__ = [ + "Qwen2Config", +] + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + seq_length=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + use_recompute=False, + recompute_granularity="full", + pp_recompute_interval=1, + no_recompute_layers=None, + fuse_attention_qkv=False, + fuse_attention_ffn=False, + use_flash_attention=False, + use_fused_rms_norm=False, + use_fused_rope=False, + tensor_parallel_output=True, + sequence_parallel=False, + fuse_sequence_parallel_allreduce=False, + virtual_pp_degree=1, + tie_word_embeddings=False, + rope_theta=10000.0, + pad_token_id=0, + bos_token_id=151643, + eos_token_id=151643, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + rope_scaling_factor=1.0, + rope_scaling_type=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.seq_length = seq_length + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.use_cache = use_cache + self.use_recompute = use_recompute + self.recompute_granularity = recompute_granularity + self.no_recompute_layers = no_recompute_layers + self.pp_recompute_interval = pp_recompute_interval + self.fuse_attention_qkv = fuse_attention_qkv + self.use_flash_attention = use_flash_attention + self.fuse_attention_ffn = fuse_attention_ffn + self.use_fused_rms_norm = use_fused_rms_norm + self.tensor_parallel_output = tensor_parallel_output + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce + self.virtual_pp_degree = virtual_pp_degree + + self.use_fused_rope = use_fused_rope + self.rope_scaling_factor = rope_scaling_factor + self.rope_scaling_type = rope_scaling_type + + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py new file mode 100644 index 000000000000..6cc4b83a359a --- /dev/null +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -0,0 +1,1527 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Qwen2 model.""" + +import math +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + mark_as_sequence_parallel_parameter, +) + +from ..activations import ACT2FN +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..model_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ..model_utils import PretrainedModel, register_base_model +from .configuration import Qwen2Config + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + + +__all__ = [ + "Qwen2Model", + "Qwen2PretrainedModel", + "Qwen2ForCausalLM", + "Qwen2PretrainingCriterion", +] + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def assign_kv_heads(num_kv_heads: int, num_gpus: int): + # Initialize the assignment list + """ + Assign kv heads to different GPUs in the Tensor Parallel Setup + + Examples: + assign_kv_heads(num_kv_heads=1, num_gpus=2): [[0], [0]] + assign_kv_heads(num_kv_heads=2, num_gpus=2): [[0], [1]] + assign_kv_heads(num_kv_heads=4, num_gpus=2): [[0,1], [2,3]] + assign_kv_heads(num_kv_heads=1, num_gpus=4): [[0],[0],[0],[0]] + assign_kv_heads(num_kv_heads=2, num_gpus=4): [[0],[0],[1],[1]] + assign_kv_heads(num_kv_heads=4, num_gpus=4): [[0],[1],[2],[3]] + """ + assignment_list = [[] for _ in range(num_gpus)] + # Case 1: more heads than cards + if num_kv_heads > num_gpus: + num_heads_per_card = num_kv_heads // num_gpus + for i in range(num_gpus): + for j in range(num_heads_per_card): + assignment_list[i].append(i * num_heads_per_card + j) + # Case 2: more cards than heads. each card get only 1 head. + else: + num_card_per_heads = num_gpus // num_kv_heads + for i in range(num_kv_heads): + for j in range(num_card_per_heads): + assignment_list[i * num_card_per_heads + j].append(i) + return assignment_list + + +def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=False) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=False) + return logits + + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + training=True, + sequence_parallel=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + causal=True, + return_softmax=output_attentions, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + dropout_p=config.attention_dropout if training else 0.0, + training=training, + ) + attn_weights = None + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next tranpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # matmul and devide by sqrt(head_dim) + attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + + attn_weights = F.dropout(attn_weights, p=config.attention_dropout, training=training) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class Qwen2RMSNorm(nn.Layer): + def __init__(self, config: Qwen2Config): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +class Qwen2RotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + # [dim / 2] + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + self._set_cos_sin_cache(seq_len=max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + cos = self.cos_cached[:, :seq_len, :, :] + sin = self.sin_cached[:, :seq_len, :, :] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + if position_ids is None: + # Note: Only for Qwen2MoEForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2MLP(nn.Layer): + def __init__(self, config: Qwen2Config, is_shared=False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.tensor_parallel_degree = config.tensor_parallel_degree + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w1 + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w3 + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) # w2 + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim]) + + +class Qwen2Attention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // config.num_attention_heads + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." + ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.q_proj = ColumnParallelLinear(self.hidden_size, self.hidden_size, has_bias=True, gather_output=False) + self.k_proj = ColumnParallelLinear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False + ) + self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True) + else: + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True) + self.k_proj = nn.Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) + self.v_proj = nn.Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias_attr=False) + + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + + batch_size, seq_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + scaled_dot_product_attention, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + self.training, + self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + self.training, + self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2DecoderLayer(nn.Layer): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config, layerwise_recompute) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config) + self.post_attention_layernorm = Qwen2RMSNorm(config) + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2PretrainedModel(PretrainedModel): + config_class = Qwen2Config + base_model_prefix = "qwen2" + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + + @classmethod + def _get_name_mappings(cls, config: Qwen2Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_proj.bias", None], + [f"layers.{layer_index}.self_attn.k_proj.bias", None], + [f"layers.{layer_index}.self_attn.v_proj.bias", None], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "Qwen2MoEModel" + if "Qwen2Model" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "qwen2." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True): + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.ColumnParallelLinear, + mpu.RowParallelLinear, + Qwen2LMHead, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.qwen2.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.qwen2.config.initializer_range, + shape=layer.weight.shape, + ) + ) + if hasattr(layer, "bias") and isinstance(layer.bias, paddle.Tensor): + layer.bias.set_value(paddle.zeros_like(layer.bias)) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, Qwen2MLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, Qwen2Attention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +@register_base_model +class Qwen2Model(Qwen2PretrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.hidden_size = config.hidden_size + self.sequence_parallel = config.sequence_parallel + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.LayerList( + [ + Qwen2DecoderLayer(config, layerwise_recompute=layer_idx not in self.no_recompute_layers) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm(config) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + if attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + is_casual = is_casual_mask(attention_mask) + if is_casual: + attention_mask = None + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Qwen2PretrainingCriterion(nn.Layer): + """ + Criterion for Mixtral. + It calculates the final loss. + """ + + def __init__(self, config: Qwen2Config): + super(Qwen2PretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + + # skip ignore_index which loss == 0 + masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + loss = paddle.mean(masked_lm_loss) + + return loss + + +class Qwen2LMHead(nn.Layer): + def __init__(self, config: Qwen2Config): + super(Qwen2LMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + # Must set distributed attr for Tensor Parallel ! + self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + seq_length = self.config.seq_length + hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + return logits + + +class Qwen2ForCausalLM(Qwen2PretrainedModel): + enable_to_static_method = True + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.qwen2 = Qwen2Model(config) + self.lm_head = Qwen2LMHead(config) + self.criterion = Qwen2PretrainingCriterion(config) + self.vocab_size = config.vocab_size + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.qwen2 = decoder + + def get_decoder(self): + return self.qwen2 + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.qwen2( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen2ForSequenceClassification(Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.qwen2 = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.qwen2( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = paddle.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[paddle.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(pooled_logits.reshape([-1, self.num_labels]), labels.reshape([-1])) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.qwen2 = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.qwen2( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape([-1, self.num_labels]), labels.reshape([-1])) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py new file mode 100644 index 000000000000..952daf419cbd --- /dev/null +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -0,0 +1,232 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.distributed.fleet as fleet +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer +from paddle.distributed.fleet.utils import recompute + +from ...utils.tools import get_env_device +from ..model_utils import PipelinePretrainedModel +from .modeling import ( + Qwen2Config, + Qwen2DecoderLayer, + Qwen2LMHead, + Qwen2Model, + Qwen2PretrainedModel, + Qwen2PretrainingCriterion, + Qwen2RMSNorm, +) + +__all__ = [ + "QWenForCausalLMPipe", +] + + +def parse_args(args): + if isinstance(args, tuple): + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + elif len(args) == 2: + hidden_states, attention_mask = args + position_ids = None + elif len(args) == 1: + hidden_states = args + attention_mask, position_ids = None, None + else: + hidden_states = args + attention_mask, position_ids = None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + return hidden_states, attention_mask, position_ids + + +def return_args(hidden_states, attention_mask=None, position_ids=None): + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if len(ret) == 1: + ret = ret[0] + + return ret + + +class Qwen2EmbeddingPipe(nn.Layer): + """Extends QWenEmbeddings to forward attention_mask through the pipeline.""" + + def __init__(self, config: Qwen2Config): + super().__init__() + self.hidden_size = config.hidden_size + if config.tensor_parallel_degree > 1: + self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + def forward(self, args): + """_summary_ + + Args: + input (_type_): _description_ + + Returns: + _type_: _description_ + """ + input_ids, attention_mask, position_ids = parse_args(args) + input_embeds = self.embed_tokens(input_ids) + if self.sequence_parallel: + from paddlenlp.transformers import ScatterOp + + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = input_embeds.shape + input_embeds = paddle.reshape_(input_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + input_embeds = ScatterOp.apply(input_embeds) + + batch_size, seq_length = input_ids.shape + + if attention_mask is not None: + attention_mask = Qwen2Model._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), 0, input_embeds.dtype + ) + attention_mask.stop_gradient = True + if get_env_device() == "npu": + attention_mask = attention_mask.astype("bool") + elif get_env_device() == "npu": + attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool")) + attention_mask.stop_gradient = True + + return return_args(input_embeds, attention_mask, position_ids) + + +class Qwen2DecoderLayerPipe(Qwen2DecoderLayer): + def forward(self, args): + hidden_states, attention_mask, position_ids = parse_args(args) + + has_gradient = not hidden_states.stop_gradient + + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + if attention_mask is not None: + hidden_states = recompute( + super().forward, hidden_states, attention_mask=attention_mask, use_reentrant=False + ) + else: + # for pretrain + hidden_states = recompute( + super().forward, hidden_states, use_reentrant=self.config.recompute_use_reentrant + ) + else: + hidden_states = super().forward(hidden_states, attention_mask=attention_mask) + + return return_args(hidden_states, attention_mask, position_ids) + + +class Qwen2RMSNormPipe(Qwen2RMSNorm): + def __init__(self, config: Qwen2Config): + super().__init__() + self.norm = Qwen2RMSNorm(config) + + def forward(self, args): + hidden_states, attention_mask, position_ids = parse_args(args) + return self.norm(hidden_states) + + +class QWenForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + """QWenForPretraining adapted for pipeline parallelism. + + The largest change is flattening the QWenModel class so we can express it as a + sequence of layers including embedding, transformer layers, and output. + """ + + config_class = Qwen2Config + + _get_tensor_parallel_mappings = Qwen2PretrainedModel._get_tensor_parallel_mappings + _init_weights = Qwen2PretrainedModel._init_weights + _keys_to_ignore_on_load_unexpected = Qwen2PretrainedModel._keys_to_ignore_on_load_unexpected + + # DONOT Add base_model_prefix !!!! + + def __init__(self, config: Qwen2Config): + self.config = config + + self.use_recompute = self.config.use_recompute + self.recompute_granularity = self.config.recompute_granularity + self.pp_recompute_interval = self.config.pp_recompute_interval + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + if self.recompute_granularity == "full": + assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" + + virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + + def get_hcg(): + return fleet.get_hybrid_communicate_group() + + hcg = get_hcg() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + + # TODO: fix tensor_parallel_degree rewrite in here + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + + self.add_sequential_layer(LayerDesc(Qwen2EmbeddingPipe, config=config), "qwen2") + for i in range(config.num_hidden_layers): + self.add_sequential_layer( + LayerDesc(Qwen2DecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers), + f"qwen2.layers.{i}", + ) + self.add_sequential_layer(LayerDesc(Qwen2RMSNormPipe, config=config), "norm") + self.add_sequential_layer(LayerDesc(Qwen2LMHead, config=config), "lm_head") + + recompute_interval = 0 + if self.use_recompute and self.recompute_granularity == "full": + assert self.config.pp_recompute_interval <= config.num_hidden_layers // ( + virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") + ), "pp recompute interval should smaller than num layers of each pp chunk" + recompute_interval = self.config.pp_recompute_interval + + seg_method = "layer:Qwen2DecoderLayer" + if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=Qwen2PretrainingCriterion(config), + topology=get_hcg().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": get_hcg().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=virtual_pp_degree, + ) + # You should call init here, since there is a diamond inheritance problem + self.apply(self._init_weights) + # DON'T init PipelinePretrainedModel + # PipelinePretrainedModel.__init__(self.super(), config=config) diff --git a/paddlenlp/transformers/qwen2/tokenizer.py b/paddlenlp/transformers/qwen2/tokenizer.py new file mode 100644 index 000000000000..f637e4a10706 --- /dev/null +++ b/paddlenlp/transformers/qwen2/tokenizer.py @@ -0,0 +1,333 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Qwen2.""" + +import json +import os +import unicodedata +from functools import lru_cache +from typing import Optional, Tuple + +import regex as re + +from .. import AddedToken, PretrainedTokenizer + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +__all__ = ["Qwen2Tokenizer"] + +MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768} + +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Qwen2Tokenizer(PretrainedTokenizer): + """ + Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding. + + Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import Qwen2Tokenizer + + >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer") + >>> tokenizer("Hello world")["input_ids"] + [9707, 1879] + + >>> tokenizer(" Hello world")["input_ids"] + [21927, 1879] + ``` + This is expected. + + You should not use GPT2Tokenizer instead, because of the different pretokenization rules. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*): + The beginning of sequence token. Not applicable for this tokenizer. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding, for example when batching sequences of different lengths. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not the model should cleanup the spaces that were added when splitting the input text during the + tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces. + split_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the special tokens should be split during the tokenization process. The default behavior is + to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") = + ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<', + '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment. + """ + + resource_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + max_model_input_sizes = MAX_MODEL_INPUT_SIZES + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + clean_up_tokenization_spaces=False, + split_special_tokens=False, + **kwargs, + ): + super().__init__(**kwargs) + # Qwen vocab does not contain control tokens; added tokens need to be special + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(pad_token, str) + else pad_token + ) + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_merges = [] + with open(merges_file, encoding="utf-8") as merges_handle: + for i, line in enumerate(merges_handle): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + bpe_merges.append(tuple(line.split())) + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # NOTE: the cache can grow without bound and will get really large for long running processes + # (esp. for texts of language that do not use space between word, e.g. Chinese); technically + # not a memory leak but appears as one. + # GPT2Tokenizer has the same problem, so let's be consistent. + self.cache = {} + + self.pat = re.compile(PRETOKENIZE_REGEX) + + # if kwargs.get("add_prefix_space", False): + # logger.warning_once( + # f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect." + # ) + self.bos_token_id = kwargs["bos_token_id"] if "bos_token_id" in kwargs else None + self.eos_token_id = kwargs["eos_token_id"] if "eos_token_id" in kwargs else None + self.unk_token_id = kwargs["unk_token_id"] if "unk_token_id" in kwargs else None + self.pad_token_id = kwargs["pad_token_id"] if "pad_token_id" in kwargs else None + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.added_tokens_encoder.get(token, len(self.encoder))) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index, self.added_tokens_decoder.get(index, self.unk_token)) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def _decode( + self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers + # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer + return super()._decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + # if not os.path.isdir(save_directory): + # logger.error(f"Vocabulary path ({save_directory}) should be a directory") + # return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + # logger.warning( + # f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + # " Please check that the tokenizer is not corrupted!" + # ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, **kwargs): + text = unicodedata.normalize("NFC", text) + return (text, kwargs) diff --git a/paddlenlp/transformers/qwen2_moe/__init__.py b/paddlenlp/transformers/qwen2_moe/__init__.py new file mode 100644 index 000000000000..2f2acfa9b339 --- /dev/null +++ b/paddlenlp/transformers/qwen2_moe/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..qwen2.tokenizer import * +from .configuration import * +from .modeling import * diff --git a/paddlenlp/transformers/qwen2_moe/configuration.py b/paddlenlp/transformers/qwen2_moe/configuration.py new file mode 100644 index 000000000000..46fe1f91437d --- /dev/null +++ b/paddlenlp/transformers/qwen2_moe/configuration.py @@ -0,0 +1,203 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Qwen2Moe model configuration""" + +from paddlenlp.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "Qwen2MoeConfig", +] + + +class Qwen2MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a + Qwen2Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B"). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2Moe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 1408): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 5632): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 60): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `False`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from paddlenlp.transformers import Qwen2MoeModel, Qwen2MoeConfig + + >>> # Initializing a Qwen2Moe style configuration + >>> configuration = Qwen2MoeConfig() + + >>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration + >>> model = Qwen2MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=8192, + seq_length=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + use_recompute=False, + recompute_granularity="full", + no_recompute_layers=None, + use_flash_attention=False, + attention_dropout=0.0, + use_fused_rope=False, + rope_theta=1000000.0, + tensor_parallel_output=True, + sequence_parallel=False, + fuse_sequence_parallel_allreduce=False, + pad_token_id=0, + bos_token_id=151643, + eos_token_id=151643, + tie_word_embeddings=False, + use_sliding_window=False, + sliding_window=32768, + max_window_layers=28, + decoder_sparse_step=1, + moe_intermediate_size=1408, + shared_expert_intermediate_size=5632, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.seq_length = seq_length + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.use_recompute = use_recompute + self.recompute_granularity = recompute_granularity + self.no_recompute_layers = no_recompute_layers + self.use_flash_attention = use_flash_attention + self.tensor_parallel_output = tensor_parallel_output + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce + + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_fused_rope = use_fused_rope + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + tensor_parallel_output=tensor_parallel_output, + **kwargs, + ) diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py new file mode 100644 index 000000000000..da993c319c98 --- /dev/null +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -0,0 +1,1553 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Paddle Qwen2Moe model.""" +from __future__ import annotations + +import math +import warnings +from functools import partial +from typing import Optional, Tuple + +import paddle +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + mark_as_sequence_parallel_parameter, +) + +from ...utils.log import logger +from ..activations import ACT2FN +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast +from ..model_utils import PretrainedModel, register_base_model +from .configuration import Qwen2MoeConfig + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +__all__ = [ + "Qwen2MoeModel", + "Qwen2MoePretrainedModel", + "Qwen2MoeForCausalLM", + "Qwen2MoePretrainingCriterion", +] + + +def load_balancing_loss_func(gate_logits, num_experts, top_k=2, attention_mask=None): + """ + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Paddle. + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + Args: + gate_logits (Union[`paddle.Tensor`, Tuple[paddle.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top k experts to be considered for the loss computation. + attention_mask (`paddle.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + concatenated_gate_logits = paddle.concat( + gate_logits, axis=0 + ) # [num_hidden_layers X batch_size X sequence_length, num_experts] + + routing_weights = F.softmax(concatenated_gate_logits, axis=-1) + _, selected_experts = paddle.topk(routing_weights, top_k, axis=-1) + expert_mask = F.one_hot( + selected_experts, num_classes=num_experts + ) # [num_hidden_layers X batch_size X sequence_length, top_k, num_experts] + + if attention_mask is None or len(attention_mask.shape) == 4: + # Only intokens strategy has 4-D attention_mask, we currently do not support excluding padding tokens. + # Compute the percentage of tokens routed to each experts + tokens_per_expert = paddle.mean(expert_mask.astype("float32"), axis=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = paddle.mean(routing_weights, axis=0) + else: + # Exclude the load balancing loss of padding tokens. + if len(attention_mask.shape) == 2: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape([-1, top_k, num_experts]) + ) # [num_hidden_layers * batch_size * sequence_length, top_k, num_experts] + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = paddle.sum(expert_mask.astype("float32") * expert_attention_mask, axis=0) / paddle.sum( + expert_attention_mask, axis=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape([-1, num_experts]) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = paddle.sum( + routing_weights * router_per_expert_attention_mask, axis=0 + ) / paddle.sum(router_per_expert_attention_mask, axis=0) + + overall_loss = paddle.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def assign_kv_heads(num_kv_heads: int, num_gpus: int): + # Initialize the assignment list + """ + Assign kv heads to different GPUs in the Tensor Parallel Setup + + Examples: + assign_kv_heads(num_kv_heads=1, num_gpus=2): [[0], [0]] + assign_kv_heads(num_kv_heads=2, num_gpus=2): [[0], [1]] + assign_kv_heads(num_kv_heads=4, num_gpus=2): [[0,1], [2,3]] + assign_kv_heads(num_kv_heads=1, num_gpus=4): [[0],[0],[0],[0]] + assign_kv_heads(num_kv_heads=2, num_gpus=4): [[0],[0],[1],[1]] + assign_kv_heads(num_kv_heads=4, num_gpus=4): [[0],[1],[2],[3]] + """ + assignment_list = [[] for _ in range(num_gpus)] + # Case 1: more heads than cards + if num_kv_heads > num_gpus: + num_heads_per_card = num_kv_heads // num_gpus + for i in range(num_gpus): + for j in range(num_heads_per_card): + assignment_list[i].append(i * num_heads_per_card + j) + # Case 2: more cards than heads. each card get only 1 head. + else: + num_card_per_heads = num_gpus // num_kv_heads + for i in range(num_kv_heads): + for j in range(num_card_per_heads): + assignment_list[i * num_card_per_heads + j].append(i) + return assignment_list + + +def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=False) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=False) + return logits + + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + training=True, + sequence_parallel=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + causal=True, + return_softmax=output_attentions, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + dropout_p=config.attention_dropout if training else 0.0, + training=training, + ) + attn_weights = None + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next tranpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # matmul and devide by sqrt(head_dim) + attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + + attn_weights = F.dropout(attn_weights, p=config.attention_dropout, training=training) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class Qwen2MoeRMSNorm(nn.Layer): + def __init__(self, config: Qwen2MoeConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +class Qwen2MoeRotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + # [dim / 2] + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + self._set_cos_sin_cache(seq_len=max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + cos = self.cos_cached[:, :seq_len, :, :] + sin = self.sin_cached[:, :seq_len, :, :] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + if position_ids is None: + # Note: Only for Qwen2MoeForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe +class Qwen2MoeMLP(nn.Layer): + def __init__(self, config: Qwen2MoeConfig, is_shared=False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = ( + config.moe_intermediate_size if not is_shared else config.shared_expert_intermediate_size + ) + self.tensor_parallel_degree = config.tensor_parallel_degree + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w1 + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w3 + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) # w2 + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim]) + + +class Qwen2MoeAttention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2MoeConfig, layerwise_recompute: bool = True): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // config.num_attention_heads + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." + ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.q_proj = ColumnParallelLinear(self.hidden_size, self.hidden_size, has_bias=True, gather_output=False) + self.k_proj = ColumnParallelLinear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False + ) + self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True) + else: + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True) + self.k_proj = nn.Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) + self.v_proj = nn.Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias_attr=False) + + self.rotary_emb = Qwen2MoeRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + + batch_size, seq_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + scaled_dot_product_attention, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + self.training, + self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + self.training, + self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2MoeSparseMoEBlock(nn.Layer): + def __init__(self, config: Qwen2MoeConfig): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) + self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) + + self.shared_expert = Qwen2MoeMLP(config, is_shared=True) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) + + def forward(self, hidden_states): + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape([-1, hidden_dim]) + # router_logits: [batch_size * seq_len, num_experts] + router_logits = self.gate(hidden_states) + + with paddle.amp.auto_cast(False): + routing_weights = F.softmax(router_logits.astype("float32"), axis=1) + routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1) + if self.norm_topk_prob: # Note: Mixtral is set norm as default, Qwen2Moe is set to no norm + routing_weights /= routing_weights.sum(axis=-1, keepdim=True) + # we cast back to input dtype + routing_weights = routing_weights.astype(hidden_states.dtype) + + final_hidden_states = paddle.zeros( + [batch_size * seq_len, hidden_dim], + dtype=hidden_states.dtype, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated. + # shape: [num_experts, top_k, batch_size * seq_len] + expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) + + # Loop over all available experts in the model and perform the computation on each expert. + for expert_id in range(self.num_experts): + expert_layer = self.experts[expert_id] + idx, top_x = paddle.where(expert_mask[expert_id]) + + if top_x.shape[0] == 0: + continue + + current_state = paddle.gather(hidden_states, top_x.squeeze()) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx] + + top_x = top_x.squeeze() + if top_x.shape == []: + top_x = paddle.to_tensor([top_x.item()]) + final_hidden_states = paddle.index_add_( + final_hidden_states, top_x, 0, current_hidden_states.astype(hidden_states.dtype) + ) + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + final_hidden_states = final_hidden_states + shared_expert_output + + final_hidden_states = final_hidden_states.reshape([batch_size, seq_len, hidden_dim]) + return final_hidden_states, router_logits + + +class Qwen2MoeDecoderLayer(nn.Layer): + def __init__(self, config: Qwen2MoeConfig, layerwise_recompute: bool = False): + super().__init__() + self.config = config + + self.self_attn = Qwen2MoeAttention(config, layerwise_recompute) + + if config.num_experts > 0: + self.mlp = Qwen2MoeSparseMoEBlock(config) + else: + # num_experts == 0 or this layer is not sparse layer + self.mlp = Qwen2MoeMLP(config) + + self.input_layernorm = Qwen2MoeRMSNorm(config) + self.post_attention_layernorm = Qwen2MoeRMSNorm(config) + + self.sequence_parallel = config.sequence_parallel + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2MoePretrainedModel(PretrainedModel): + config_class = Qwen2MoeConfig + base_model_prefix = "qwen2_moe" + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + + @classmethod + def _get_name_mappings(cls, config: Qwen2MoeConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_proj.bias", None], + [f"layers.{layer_index}.self_attn.k_proj.bias", None], + [f"layers.{layer_index}.self_attn.v_proj.bias", None], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + for expert_idx in range(config.num_experts): + expert_mappings = [ + [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.up_proj.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"]) + + model_mappings.append([f"layers.{layer_index}.mlp.shared_expert.gate_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_expert.down_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_expert.up_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_expert_gate.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "Qwen2MoeModel" + if "Qwen2MoeModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "qwen2_moe." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: Qwen2MoeConfig, is_split=True): + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers, num_experts): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + # Add tp split for expert params. + base_actions = { + "layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True), + } + for key, action in base_actions.items(): + for i in range(num_layers): + newkey = key.replace("layers.0.", f"layers.{i}.") + for j in range(num_experts): + newkey2 = newkey.replace("experts.0.", f"experts.{j}.") + final_actions[newkey2] = action + + # Add tp split for shared expert params. + base_actions = { + "layers.0.mlp.shared_expert.gate_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.shared_expert.up_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.shared_expert.down_proj.weight": partial(fn, is_column=False), + } + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, config.num_experts) + + return mappings + + def _init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.ColumnParallelLinear, + mpu.RowParallelLinear, + Qwen2MoeLMHead, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.qwen2_moe.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.qwen2_moe.config.initializer_range, + shape=layer.weight.shape, + ) + ) + if hasattr(layer, "bias") and isinstance(layer.bias, paddle.Tensor): + layer.bias.set_value(paddle.zeros_like(layer.bias)) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, Qwen2MoeMLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, Qwen2MoeAttention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +@register_base_model +class Qwen2MoeModel(Qwen2MoePretrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] + Args: + config: Qwen2MoeConfig + """ + + def __init__(self, config: Qwen2MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.sequence_parallel = config.sequence_parallel + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.LayerList( + [ + Qwen2MoeDecoderLayer(config, layerwise_recompute=layer_idx not in self.no_recompute_layers) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2MoeRMSNorm(config) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + output_router_logits: bool, + past_key_value: Tensor, + use_cache: bool, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + output_router_logits, + past_key_value, + use_cache, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + output_router_logits: Optional[bool] = None, + return_dict=False, + **kwargs, + ): + if self.sequence_parallel and use_cache: + raise ValueError("We currently only support sequence parallel without cache.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + if attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + is_casual = is_casual_mask(attention_mask) + if is_casual: + attention_mask = None + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + output_router_logits, + past_key_value, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + output_router_logits, + past_key_value, + use_cache, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class Qwen2MoePretrainingCriterion(nn.Layer): + """ + Criterion for Mixtral. + It calculates the final loss. + """ + + def __init__(self, config: Qwen2MoeConfig): + super(Qwen2MoePretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + + # skip ignore_index which loss == 0 + masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + loss = paddle.mean(masked_lm_loss) + + return loss + + +class Qwen2MoeLMHead(nn.Layer): + def __init__(self, config: Qwen2MoeConfig): + super(Qwen2MoeLMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + # Must set distributed attr for Tensor Parallel ! + self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + seq_length = self.config.seq_length + hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + return logits + + +class Qwen2MoeForCausalLM(Qwen2MoePretrainedModel): + enable_to_static_method = True + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Qwen2MoeConfig): + super().__init__(config) + self.config = config + + self.qwen2_moe = Qwen2MoeModel(config) + self.lm_head = Qwen2MoeLMHead(config) + self.criterion = Qwen2MoePretrainingCriterion(config) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + + if config.sliding_window: + self.config.sliding_window = False + logger.warning("We do not support sliding window attention for now.") + + def get_input_embeddings(self): + return self.qwen2_moe.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2_moe.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.qwen2_moe = decoder + + def get_decoder(self): + return self.qwen2_moe + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, MoECausalLMOutputWithPast) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + output_router_logits: Optional[bool] = None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.qwen2_moe( + input_ids=input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] # [bs, seq_len, dim] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/tests/transformers/qwen2moe/__init__.py b/tests/transformers/qwen2moe/__init__.py new file mode 100644 index 000000000000..fd05a9208165 --- /dev/null +++ b/tests/transformers/qwen2moe/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/qwen2moe/test_modeling.py b/tests/transformers/qwen2moe/test_modeling.py new file mode 100644 index 000000000000..d2b11e76f6b0 --- /dev/null +++ b/tests/transformers/qwen2moe/test_modeling.py @@ -0,0 +1,317 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import unittest + +import paddle + +from paddlenlp.transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM, Qwen2MoeModel +from tests.transformers.test_configuration_common import ConfigTester +from tests.transformers.test_generation_utils import GenerationTesterMixin +from tests.transformers.test_modeling_common import ( + ModelTesterMixin, + ids_tensor, + random_attention_mask, +) + + +class Qwen2MoeModelTester: + def __init__( + self, + parent, + vocab_size=32000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=8, + masked_softmax_fusion=True, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + is_training=True, + use_cache=False, + bos_token_id=1, + eos_token_id=2, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + attention_softmax_in_fp32=True, + pretraining_tp=1, # TP rank used when training with megatron + dtype="bfloat16", + slow_but_exact=False, + batch_size: int = 2, + seq_length: int = 10, + type_sequence_label_size=2, + activation_function="gelu", + num_labels=3, + num_choices=4, + scope=None, + dropout=0.56, + use_input_mask: bool = False, + use_labels: bool = False, + return_dict=False, + ): + self.parent: Qwen2MoeModelTest = parent + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.is_training = is_training + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.pretraining_tp = pretraining_tp + self.dtype = dtype + self.slow_but_exact = slow_but_exact + + self.batch_size = batch_size + self.seq_length = seq_length + self.type_sequence_label_size = type_sequence_label_size + self.activation_function = activation_function + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.dropout = dropout + + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.return_dict = return_dict + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self) -> Qwen2MoeConfig: + return Qwen2MoeConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + masked_softmax_fusion=self.masked_softmax_fusion, + layer_norm_epsilon=self.layer_norm_epsilon, + initializer_range=self.initializer_range, + use_cache=self.use_cache, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, + hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_dropout, + attention_softmax_in_fp32=self.attention_softmax_in_fp32, + pretraining_tp=self.pretraining_tp, + dtype=self.dtype, + slow_but_exact=self.slow_but_exact, + activation_function=self.activation_function, + ) + + def create_and_check_model( + self, config: Qwen2MoeConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = Qwen2MoeModel(config) + model.eval() + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + + def create_and_check_model_attention_mask( + self, config: Qwen2MoeConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = Qwen2MoeModel(config) + model.eval() + attn_mask_2d = random_attention_mask([self.batch_size, self.seq_length]) + result_2d = model(input_ids, attention_mask=attn_mask_2d)[0] + batch, seq_length = input_ids.shape + causal_mask = paddle.tril(paddle.ones((batch, seq_length, seq_length), dtype=attn_mask_2d.dtype)) + attn_mask_3d = causal_mask & attn_mask_2d.unsqueeze(-1) + result_3d = model(input_ids, attention_mask=attn_mask_3d)[0] + attn_mask_4d = attn_mask_3d.unsqueeze(1) + result_4d = model(input_ids, attention_mask=attn_mask_4d)[0] + result_no_attention_mask = model(input_ids, attention_mask=None)[0] + # Assert non-padding tokens have the same logits with different attention_mask shape + self.parent.assertTrue((result_2d[attn_mask_2d] == result_3d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_4d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all()) + + def create_and_check_model_past_large_inputs( + self, + config: Qwen2MoeConfig, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = Qwen2MoeModel(config) + model.eval() + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True, return_dict=self.return_dict) + past_key_values = outputs.past_key_values if self.return_dict else outputs[2] + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), self.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = paddle.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = paddle.concat([input_mask, next_mask], axis=-1) + + outputs = model( + next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True, return_dict=self.return_dict + ) + + output_from_no_past = outputs[2][0] + + outputs = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + return_dict=self.return_dict, + ) + + output_from_past = outputs[2][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(paddle.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args): + model = Qwen2MoeForCausalLM(config) + model.eval() + + result = model( + input_ids, + use_cache=True, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertIsInstance(result[0].item(), float) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size]) + else: + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def check_model_position_ids(self, config, input_ids, input_mask, *args): + model = Qwen2MoeForCausalLM(config) + model.eval() + + result_no_position_id = model( + input_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + batch_size, seq_len = input_ids.shape + position_ids = paddle.arange(seq_len).expand((batch_size, seq_len)) + result_position_id = model( + input_ids, + position_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all()) + else: + self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all()) + + +class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + base_model_class = Qwen2MoeModel + return_dict = False + use_labels = False + use_test_model_name_list = False + + all_model_classes = (Qwen2MoeModel, Qwen2MoeForCausalLM) + all_generative_model_classes = {Qwen2MoeForCausalLM: (Qwen2MoeModel, "qwen2_moe")} + + def setUp(self): + super().setUp() + + self.model_tester = Qwen2MoeModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen2MoeConfig, vocab_size=256, hidden_size=24) + + def _get_input_ids_and_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_ids = inputs_dict[self.input_name] + attention_mask = paddle.ones_like(input_ids, dtype=paddle.int64) + + max_batch_size = 2 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:max_batch_size, :sequence_length] + attention_mask = attention_mask[:max_batch_size, :sequence_length] + max_length = 3 + + return config, input_ids, attention_mask, max_length + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_attention_mask(*config_and_inputs) + + def test_model_position_ids(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_model_position_ids(*config_and_inputs) + + def test_generate_without_input_ids(self): + # this requires 4-D attention mask logic, which is not supported yet + pass + + def test_qwen2moe_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs)