
Commit

Reduce CPU host overhead when using MoE
ranzhejiang committed May 29, 2024
1 parent 2fc702e commit 23ec4a1
Showing 1 changed file with 3 additions and 3 deletions.
deepspeed/moe/sharded_moe.py: 3 additions, 3 deletions
@@ -206,7 +206,7 @@ def top1gating(logits: Tensor,
         mask1 = einsum("s,se->se", used_token, mask1)

     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+    exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device)

     # if we don't want to drop any tokens
     if not drop_tokens:
@@ -322,7 +322,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts

     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)

     if drop_tokens:
         # Calculate configured capacity and remove locations outside capacity from mask
@@ -366,7 +366,7 @@ def top2gating(logits: Tensor,
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()

-    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
+    return l_aux, combine_weights, dispatch_mask, exp_counts


 class TopKGate(Module):
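
For context beyond the commit itself: on a CUDA device, .to('cpu') forces a blocking device-to-host copy, so the host thread stalls until every kernel queued ahead of the reduction has finished. Reducing into a tensor that stays on logits.device avoids that synchronization point inside the gating functions. Below is a minimal sketch of the pattern under that assumption; the tensor shapes and variable names are illustrative, not taken from the file.

import torch

# Illustrative only: contrasts the old blocking pattern with the new
# device-resident pattern from the diff above.
if torch.cuda.is_available():
    device = torch.device("cuda")
    mask1 = torch.randint(0, 2, (4096, 64), device=device)

    # Old pattern: .to('cpu') is a blocking device-to-host copy, so the host
    # waits here for all previously queued GPU work to complete.
    exp_counts_cpu = torch.sum(mask1, dim=0).detach().to("cpu")

    # New pattern: the reduction stays on the same device as the inputs, so no
    # host synchronization is triggered at this point in the gating path.
    exp_counts_dev = torch.sum(mask1, dim=0).detach().to(device)

The trade-off is that any caller that actually needs exp_counts on the host still pays for the transfer, but only at the point where the value is read on the CPU rather than on every gating call.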
