Skip to content

Commit

Permalink
reorder operations
Browse files Browse the repository at this point in the history
  • Loading branch information
scv119 committed Jan 17, 2024
1 parent f955162 commit 1089dd8
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions vllm/model_executor/layers/moe.py
Original file line number Diff line number Diff line change
@@ -111,16 +111,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 expanded_hidden_states, experts_range, self.w1s.data,
 self.w2s.data, self.w3s.data)

-# Step 3: apply weights to the output of each expert, and reduce
-# across ranks.
+# Step 3: apply weights to the output of each expert
 expanded_hidden_states.mul_(expanded_weights.unsqueeze(-1))
-tensor_model_parallel_all_reduce(expanded_hidden_states)

 # Step 4: merge the output of each expert, according to the indices.
-return self.merge_expert_outputs(expanded_hidden_states,
-reverse_indices).view(
-batch_size, sequence_length,
-hidden_size)
+merged_hidden_states = self.merge_expert_outputs(
+expanded_hidden_states,
+reverse_indices).view(batch_size, sequence_length, hidden_size)
+
+# Step 5: reduce across ranks.
+tensor_model_parallel_all_reduce(merged_hidden_states)
+return merged_hidden_states

 def expand_and_permutate_hidden_states(
 self,
Expand Down

0 comments on commit 1089dd8

Please sign in to comment.