Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change llama/modeling.py to opt npu performence #8342

Merged
merged 7 commits into from
Apr 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def swiglu(x, y=None):
"LlamaPretrainingCriterion",
]

npu_is_casual = False

def _get_interleave(n):
def _get_interleave_power_of_2(n):
Expand Down Expand Up @@ -244,7 +245,7 @@ def scaled_dot_product_attention(
attention_mask is None,
True,
False,
False,
npu_is_casual,
)[0]
else:
attn_output = F.scaled_dot_product_attention(
Expand Down Expand Up @@ -1118,6 +1119,7 @@ def __init__(self, config, layerwise_recompute: bool = False):
self.layerwise_recompute = layerwise_recompute
self.recompute_granularity = config.recompute_granularity


def forward(
self,
hidden_states: paddle.Tensor,
Expand Down Expand Up @@ -1612,11 +1614,12 @@ def forward(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
if self.config.use_flash_attention:
is_casual = is_casual_mask(attention_mask)
ZHUI marked this conversation as resolved.
Show resolved Hide resolved
if get_env_device() != "npu":
is_casual = is_casual_mask(attention_mask)
if is_casual and alibi is None:
attention_mask = None
else:
npu_is_casual = is_casual
attention_mask = attention_mask.astype("bool")
hidden_states = inputs_embeds
# decoder layers
Expand Down Expand Up @@ -1722,9 +1725,12 @@ def forward(self, prediction_scores, masked_lm_labels):
_hcg = fleet.get_hybrid_communicate_group()
masked_lm_loss = ConcatSePMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group())
# skip ignore_index which loss == 0
masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
loss = paddle.mean(masked_lm_loss)

# masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
# loss = paddle.mean(masked_lm_loss)
binary_sequence = paddle.where(masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss))
sum_ = paddle.sum(binary_sequence)
loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_

return loss


Expand Down
Loading