Commit

Fix graph breaks in Mixtral (#65) (huggingface#1705)
ShengYang1 authored Feb 5, 2025
1 parent d214819 commit 6a520ff
Showing 1 changed file with 19 additions and 21 deletions.
optimum/habana/transformers/models/mixtral/modeling_mixtral.py (40 changes: 19 additions & 21 deletions)
@@ -20,7 +20,6 @@
 
 """PyTorch Mixtral model."""
 
-import contextlib
 import math
 import os
 from typing import List, Optional, Tuple, Union
@@ -76,18 +75,12 @@
     print("Not using HPU fused kernel for apply_rotary_pos_emb")
     FusedRoPE = None
 
-try:
-    from habana_frameworks.torch.hpu import sdp_kernel
-
-    SDPContext = True
-except ImportError:
-    SDPContext = False
-
+deepspeed_available = is_deepspeed_available()
 logger = logging.get_logger(__name__)
 
 
 def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
-    if q.device.type == "hpu" and FusedRoPE:
+    if q.device.type == "hpu" and FusedRoPE is not None:
         return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
     else:
         return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
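Together with the dropped contextlib import above, this hunk removes the sdp_kernel/SDPContext fallback, caches is_deepspeed_available() in a module-level deepspeed_available flag, and switches the fused-kernel guard from a truthiness test to an identity test against None. A minimal sketch of the guarded-import pattern; the import path is illustrative only and not taken from this commit:

# Resolve the optional HPU kernel once at import time and keep a plain None sentinel.
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA  # assumed path; fails cleanly off Gaudi
except ImportError:
    FusedSDPA = None


def can_use_fused_kernel(device_type: str) -> bool:
    # Call sites now do a cheap identity check instead of evaluating the kernel object as a bool.
    return device_type == "hpu" and FusedSDPA is not None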
@@ -99,7 +92,7 @@ def gaudi_mixtral_rmsnorm_forward(self, hidden_states):
     The only differences are:
     - override RMSNorm with Habana fused RMSNorm
     """
-    if hidden_states.device.type == "hpu" and FusedRMSNorm:
+    if hidden_states.device.type == "hpu" and FusedRMSNorm is not None:
         # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
         if hidden_states.dtype != self.weight.dtype:
             orig_dtype = hidden_states.dtype
@@ -307,7 +300,7 @@ def forward(
         else:
             past_key_value = None
 
-        if FusedSDPA:
+        if FusedSDPA is not None:
             if query_states.dtype != key_states.dtype:
                 key_states = key_states.type(query_states.dtype)
                 value_states = value_states.type(query_states.dtype)
@@ -324,12 +317,17 @@
                     )
                 htcore.mark_step()
             else:
-                with (
-                    sdp_kernel(enable_recompute=flash_attention_recompute) if SDPContext else contextlib.nullcontext()
-                ):
-                    attn_output = FusedSDPA.apply(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None
-                    )
+                attn_output = FusedSDPA.apply(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask,
+                    0.0,
+                    False,
+                    None,
+                    "None",
+                    flash_attention_recompute,
+                )
         else:
             query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv(
                 query_states, key_states, value_states, attention_mask, self.num_key_value_groups
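The change above is the main fix on this attention path: the sdp_kernel(enable_recompute=...) context manager, guarded by the now-removed SDPContext flag, is gone, and the recompute choice rides on FusedSDPA.apply itself as trailing arguments. A hedged sketch of the new call shape, runnable anywhere thanks to the fallback; the meaning of the trailing "None" string (softmax mode) and the recompute flag is assumed from the argument order in the diff, not documented by this commit:

import torch.nn.functional as F

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA  # assumed import path
except ImportError:
    FusedSDPA = None


def sdpa(q, k, v, mask, recompute: bool = False):
    if FusedSDPA is not None:
        # Argument order mirrors the diff: dropout 0.0, is_causal False, scale None,
        # then the assumed softmax-mode string and the recompute flag.
        return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, "None", recompute)
    # Plain PyTorch fallback so the sketch runs without habana_frameworks.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)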
@@ -353,7 +351,7 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions or FusedSDPA:
+        if not output_attentions or FusedSDPA is not None:
            attn_weights = None
 
         return attn_output, attn_weights, past_key_value
@@ -379,7 +377,7 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) ->
     # router_logits: (batch * sequence_length, n_experts)
     router_logits = self.gate(hidden_states)
 
-    if is_deepspeed_available() and (not self.training):
+    if deepspeed_available and (not self.training):
         from deepspeed import comm as dist
 
         if dist.is_initialized():
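This hunk and the two below swap per-forward calls to is_deepspeed_available() for the deepspeed_available flag computed once near the top of the file. A minimal sketch of the pattern; the helper name is hypothetical and the body only mirrors the gating visible in the diff, not the actual MoE forward:

from transformers.integrations.deepspeed import is_deepspeed_available  # assumed import location

# Evaluated once at import time; forward passes only read a plain Python bool.
deepspeed_available = is_deepspeed_available()


def on_initialized_deepspeed_group(training: bool) -> bool:
    # Cheap cached flag first, then the heavier deepspeed import and communicator
    # query only when it can matter.
    if deepspeed_available and not training:
        from deepspeed import comm as dist

        return dist.is_initialized()
    return False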
Expand Down Expand Up @@ -427,7 +425,7 @@ def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

if is_deepspeed_available() and (not self.training):
if deepspeed_available and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
@@ -453,7 +451,7 @@ def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -
         experts_min=0,
         experts_max=7,
     )
-    if is_deepspeed_available() and (not self.training):
+    if deepspeed_available and (not self.training):
         from deepspeed import comm as dist
 
         if dist.is_initialized():
