diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py
index c3bd05fea133..5e09478be83b 100644
--- a/llm/predict/predictor.py
+++ b/llm/predict/predictor.py
@@ -134,7 +134,10 @@ class PredictorArgument:
 
     @property
     def total_max_length(self):
-        return 8192  # Maximum sequence length.
+        if self.device == "npu":
+            return self.src_length + self.max_length
+        else:
+            return 8192  # Maximum sequence length.
 
 
 @dataclass
@@ -861,6 +864,35 @@ def init_model_inputs(self, config: PredictorArgument):
             self.model_inputs["tgt_mask"] = (
                 alibi_decoder + (1 - self.model_inputs["tgt_mask"]) * paddle.finfo(self.dtype).min
             ).cast(self.dtype)
+        elif config.device == "npu" and self.model_config.get("alibi", False):
+            lower_one_tril = paddle.tril(
+                paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)
+            )
+            lower_one_tril = lower_one_tril[None, None, :, :]
+            src_mask = lower_one_tril.tile([config.batch_size, 1, 1, 1])
+            tgt_mask = paddle.full(
+                shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype
+            )
+            arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype)
+            alibi_slopes = llm_utils.get_alibi_slopes(self.num_attention_heads)
+            alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
+            alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1])
+            alibi_decoder = alibi.tile(
+                [
+                    config.batch_size,
+                    1,
+                    1,
+                    1,
+                ]
+            )
+            # self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated!
+            src_mask = (
+                alibi_encoder + (1 - src_mask) * paddle.finfo(self.dtype).min
+            ).cast(self.dtype)
+            tgt_mask = (
+                alibi_decoder + (1 - tgt_mask) * paddle.finfo(self.dtype).min
+            ).cast(self.dtype)
+            self.model_inputs["rope_emb"] = paddle.concat([src_mask.reshape([-1]), tgt_mask.reshape([-1])])
 
     def _preprocess(self, input_text: list[str]):
         if self.tokenizer.chat_template is not None:
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index 099436ef6304..6100c2080a74 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -39,8 +39,7 @@
         "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
         "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
     )
-if core.is_compiled_with_xpu() or core.is_compiled_with_cuda():
-    from paddlenlp_ops import rebuild_padding_v2
+from paddlenlp_ops import rebuild_padding_v2
 
 if core.is_compiled_with_cuda():
     if os.getenv("FLAGS_CUTLASS_FP8_GEMM", "False") == "True":
diff --git a/paddlenlp/utils/llm_utils.py b/paddlenlp/utils/llm_utils.py
index 6ef5aae9dfb0..7a0813828397 100644
--- a/paddlenlp/utils/llm_utils.py
+++ b/paddlenlp/utils/llm_utils.py
@@ -461,7 +461,7 @@ def get_alibi_slopes(num_heads):
         extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
         num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
         extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2)
-        slopes = np.concatante([slopes, np.power(extra_base, extra_powers)], axis=0)
+        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)], axis=0)
 
     return slopes.astype("float32")
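
For reference, here is a minimal NumPy sketch (not part of the patch) of the ALiBi bias construction that the new NPU branch performs with Paddle tensors. It assumes the standard ALiBi formulation; the helper `build_alibi_masks` and its `neg_inf` parameter are illustrative names only, and the real patch flattens and concatenates the two biases into `model_inputs["rope_emb"]` for the NPU kernel rather than returning them.

```python
# Illustrative sketch of the encoder/decoder ALiBi biases, shapes [batch, heads, q_len, k_len].
import math

import numpy as np


def get_alibi_slopes(num_heads):
    # Same per-head slope schedule as paddlenlp.utils.llm_utils.get_alibi_slopes.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = np.arange(1, 1 + closest_power_of_2)
    slopes = np.power(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2)
        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)], axis=0)
    return slopes.astype("float32")


def build_alibi_masks(batch_size, num_heads, max_len, neg_inf=-1e4):
    # Prefill (encoder) step: causal lower-triangular mask over all positions.
    causal = np.tril(np.ones((max_len, max_len), dtype="float32"))[None, None, :, :]
    src_mask = np.tile(causal, (batch_size, 1, 1, 1))
    # Decode step: a single query token may attend to every cached position.
    tgt_mask = np.ones((batch_size, 1, 1, max_len), dtype="float32")

    # Per-head linear bias over key positions: slope_h * k.
    positions = np.arange(max_len, dtype="float32")
    alibi = get_alibi_slopes(num_heads)[None, :, None, None] * positions

    # Keep the bias where attention is allowed; push masked positions to a large negative value.
    src_bias = alibi + (1.0 - src_mask) * neg_inf
    tgt_bias = alibi + (1.0 - tgt_mask) * neg_inf
    return src_bias, tgt_bias


src_bias, tgt_bias = build_alibi_masks(batch_size=2, num_heads=8, max_len=16)
print(src_bias.shape, tgt_bias.shape)  # (2, 8, 16, 16) (2, 8, 1, 16)
```

The patch uses `paddle.finfo(self.dtype).min` instead of the `-1e4` placeholder above, and `total_max_length` on NPU becomes `src_length + max_length` so both biases cover the full prefill-plus-generation window.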