diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index e53798bad01d..027e91eac95e 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1653,7 +1653,7 @@ def forward( is_casual = True else: is_casual = is_casual_mask(attention_mask) - if get_env_device() != "npu" or get_env_device() != "mlu": + if get_env_device() not in ["npu", "mlu"]: if is_casual and alibi is None: attention_mask = None else: