Commit: Support rms_norm_mlu (#8504)

PeiyuLau authored Sep 2, 2024
1 parent e204b6d commit a12781f
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 3 additions & 1 deletion paddlenlp/transformers/llama/fusion_ops.py
@@ -41,7 +41,7 @@ def swiglu(x, y=None):
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:
         from paddle.base import core

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -128,6 +128,8 @@ def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
 def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
     if get_env_device() == "npu":
         return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
+    if get_env_device() == "mlu":
+        return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "gcu":
         return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "xpu":
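
For context: on MLU, RMSNorm is dispatched to a custom operator (rms_norm_mlu) registered by the device plugin, mirroring the existing NPU and GCU branches. The fused kernel is expected to match a plain eager-mode RMSNorm. Below is a minimal reference sketch of that computation, not the custom kernel itself; the float32 upcast is an assumption about typical mixed-precision handling.

    import paddle

    def rms_norm_reference(hidden_states, weight, variance_epsilon):
        # Plain (non-fused) RMSNorm: scale by the reciprocal root-mean-square of
        # the last dimension, then multiply by the learned weight.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype("float32")  # assumed fp32 accumulation
        variance = hidden_states.pow(2).mean(axis=-1, keepdim=True)
        hidden_states = hidden_states * paddle.rsqrt(variance + variance_epsilon)
        return (hidden_states * weight.astype("float32")).astype(input_dtype)
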
14 changes: 7 additions & 7 deletions paddlenlp/transformers/llama/modeling.py
@@ -77,7 +77,7 @@ def swiglu(x, y=None):
 )

 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
@@ -318,7 +318,7 @@ def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     batch_size, target_length = input_ids_shape # target_length: seq_len

-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")
     else:
         mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
@@ -338,7 +338,7 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     batch_size, src_length = mask.shape[0], mask.shape[-1]
     tgt_length = tgt_length if tgt_length is not None else src_length

-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = mask[:, None, None, :].astype(dtype)
     else:
         mask = mask[:, None, None, :].astype("bool")
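
The two mask helpers above diverge for NPU/MLU: _make_causal_mask builds an int32 lower-triangular matrix and _expand_2d_mask casts the padding mask directly to the compute dtype, instead of the bool tensors used on other devices (presumably because these backends handle boolean mask tensors poorly in the downstream ops). A toy shape illustration of the 2D-to-4D expansion, runnable on CPU:

    import paddle

    mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int32")  # [batch=1, src_length=4], 1 = token present
    expanded = mask[:, None, None, :]                        # broadcastable over heads and tgt_length
    print(expanded.shape)                                    # [1, 1, 1, 4]
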
@@ -704,7 +704,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -1485,7 +1485,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                     combined_attention_mask = _make_causal_mask(
                         input_shape, past_key_values_length=past_key_values_length
                     )
-                    if get_env_device() == "npu":
+                    if get_env_device() == "npu" or get_env_device() == "mlu":
                         expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                     else:
                         expanded_attn_mask = expanded_attn_mask & combined_attention_mask
@@ -1498,7 +1498,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        if get_env_device() == "npu":
+        if get_env_device() == "npu" or get_env_device() == "mlu":
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
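
On NPU/MLU the combined mask is then turned into an additive float mask: 0 where attention is allowed, the dtype's most negative value where it is masked. The continuation of this hunk is cut off above; presumably it selects between x and y elementwise, along the lines of this sketch:

    import paddle

    dtype = paddle.float16                               # example compute dtype
    bool_mask = paddle.to_tensor([[True, True, False]])  # True = attend, False = mask out
    x = paddle.to_tensor(0.0, dtype="float32")
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
    additive_mask = paddle.where(bool_mask, x, y).astype(dtype)  # [[0., 0., -65504.]]
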
@@ -1653,7 +1653,7 @@ def forward(
                 is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
-            if get_env_device() != "npu":
+            if get_env_device() != "npu" or get_env_device() != "mlu":
                 if is_casual and alibi is None:
                     attention_mask = None
         else:
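
One caveat in the last hunk: the compound condition get_env_device() != "npu" or get_env_device() != "mlu" is true for every device, since no device string can equal both at once, so the guard never excludes NPU or MLU. A membership test would express the presumable intent; the snippet below just demonstrates the boolean logic:

    device = "mlu"
    print(device != "npu" or device != "mlu")  # True: the `or` form never rejects any device
    print(device not in ["npu", "mlu"])        # False: the membership test skips npu and mlu as intended
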
2 changes: 2 additions & 0 deletions paddlenlp/utils/tools.py
@@ -124,6 +124,8 @@ def get_env_device():
         return "gpu"
     elif "npu" in paddle.device.get_all_custom_device_type():
         return "npu"
+    elif "mlu" in paddle.device.get_all_custom_device_type():
+        return "mlu"
     elif "gcu" in paddle.device.get_all_custom_device_type():
         return "gcu"
     elif paddle.is_compiled_with_rocm():
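
For completeness, get_env_device() discovers MLU through the custom-device plugin registry, the same mechanism used for the NPU and GCU branches. A small usage sketch, assuming the PaddleCustomDevice MLU plugin is installed and that "mlu:0" is the conventional device string for the first card:

    import paddle

    print(paddle.device.get_all_custom_device_type())  # e.g. ['mlu'] when the MLU plugin is registered
    if "mlu" in paddle.device.get_all_custom_device_type():
        paddle.set_device("mlu:0")                      # place subsequent tensors and ops on the first MLU card
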
