Skip to content

Commit

Permalink
Fix Mixtral Parity test to keep it consistent with Transformers. (mic…
Browse files Browse the repository at this point in the history
…rosoft#20210)

### Description
I recently opened a PR in hf transformers repo to fix an issue on the
indexing part.

huggingface/transformers#29857

onnx exporter was failing because of the tolist() conversion so we had
to remove it.

I found out that the code was also a part of our codebase so this PR is
to keep the code consistent.
  • Loading branch information
AdamLouly authored Apr 8, 2024
1 parent 908a76d commit 22a61a3
Showing 1 changed file with 2 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if top_x.shape[0] == 0:
continue

# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
Expand Down

0 comments on commit 22a61a3

Please sign in to comment.