From 23ec4a1c80d1a3bde1d7031feac74c4cfd6da6e9 Mon Sep 17 00:00:00 2001
From: ranzhejiang
Date: Wed, 29 May 2024 03:54:44 +0000
Subject: [PATCH] reduce cpu host overhead when using moe

---
 deepspeed/moe/sharded_moe.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 96eab5e2ab17e..ea0e7abdcb131 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -206,7 +206,7 @@ def top1gating(logits: Tensor,
         mask1 = einsum("s,se->se", used_token, mask1)
 
     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+    exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device)
 
     # if we don't want to drop any tokens
     if not drop_tokens:
@@ -322,7 +322,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts
 
     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
 
     if drop_tokens:
         # Calculate configured capacity and remove locations outside capacity from mask
@@ -366,7 +366,7 @@ def top2gating(logits: Tensor,
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
 
-    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
+    return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
 class TopKGate(Module):
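
A CUDA-to-CPU tensor copy is synchronous with respect to the host, so computing
exp_counts with .to('cpu') inside the gating functions stalled the CPU on every
MoE layer until all previously queued kernels had finished. Summing the mask on
the logits' device keeps the gating path asynchronous; the counts can still be
moved to the CPU later, outside the hot path, by any caller that needs them.
A minimal standalone sketch of the pattern (the helper names below are
illustrative, not part of DeepSpeed):

    import torch

    def gate_counts_on_device(mask1: torch.Tensor) -> torch.Tensor:
        # mask1: (num_tokens, num_experts) one-hot assignment mask.
        # The reduction stays on mask1's device, so the host thread
        # returns immediately and can keep enqueuing GPU work.
        return torch.sum(mask1, dim=0).detach()

    def gate_counts_on_cpu(mask1: torch.Tensor) -> torch.Tensor:
        # Same reduction, but .to('cpu') forces the host to wait for
        # every previously queued kernel before the copy can complete.
        return torch.sum(mask1, dim=0).detach().to('cpu')

    if torch.cuda.is_available():
        idx = torch.randint(0, 8, (4096,), device="cuda")
        mask1 = torch.nn.functional.one_hot(idx, num_classes=8)
        print(gate_counts_on_device(mask1).device)  # cuda:0, no host sync
        print(gate_counts_on_cpu(mask1).device)     # cpu, after a host stall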