
Commit 54d0b1c

[Mixtral] Change mistral op order (huggingface#27955)

younesbelkada authored Dec 11, 2023 · 1 parent 4850aab

Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -663,10 +663,10 @@ def __init__(self, config: MixtralConfig):

         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, hidden_states, routing_weights):
+    def forward(self, hidden_states):
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
         current_hidden_states = self.w2(current_hidden_states)
-        return routing_weights * current_hidden_states
+        return current_hidden_states
 
 
 MISTRAL_ATTENTION_CLASSES = {
@@ -736,7 +736,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
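
For context, the diff moves the multiplication by `routing_weights` out of the expert MLP's forward and into the sparse MoE block, so the expert keeps a standard single-argument forward and the caller scales its output. Because the scaling is a plain elementwise multiply, both orderings produce the same values. Below is a minimal sketch of that equivalence; `TinyExpertMLP` is a hypothetical stand-in for `MixtralBlockSparseTop2MLP`, with `nn.SiLU()` assumed in place of `ACT2FN[config.hidden_act]` — it is not the transformers source.

# Minimal sketch, not the actual transformers code.
import torch
import torch.nn as nn


class TinyExpertMLP(nn.Module):
    """Hypothetical stand-in for MixtralBlockSparseTop2MLP after this commit."""

    def __init__(self, hidden_dim=8, ffn_dim=16):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.act_fn = nn.SiLU()  # assumption: SiLU in place of ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        # New order: the expert returns its raw output; no routing weight here.
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        return self.w2(current_hidden_states)


torch.manual_seed(0)
expert = TinyExpertMLP()
tokens = torch.randn(4, 8)            # 4 tokens routed to this expert
routing_weights = torch.rand(4, 1)    # their routing weights

old_order = routing_weights * expert(tokens)  # old: weight applied inside/around the expert
new_order = expert(tokens) * routing_weights  # new: expert output scaled by the caller
print(torch.allclose(old_order, new_order))   # True: same values, only the op placement moved

Keeping the routing multiplication in the MoE block means the expert module matches a plain MLP call signature, which is presumably friendlier to tooling that wraps or replaces the expert's linear submodules.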
