Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add moe topk(k>2) gate support #5881

Merged
merged 8 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refine
  • Loading branch information
inkcherry committed Aug 8, 2024
commit 153af8e2c7d84391a4b67754b67c2bedb19b5196
21 changes: 5 additions & 16 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,32 +386,27 @@ def topkgating(
# everything is in fp32 in this function
# get topk gates
top_gate, top_idx = torch.topk(logits, k=k, dim=1)
# gating decisions
gates = F.softmax(logits, dim=1)
num_experts = int(gates.shape[1])

# get topk mask
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)

mask = torch.zeros_like(gates, dtype=torch.int64).scatter_(1, top_idx, 1)

# Compute tokens per expert
exp_counts = torch.sum(mask, dim=0).detach()
mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)

# gating decisions
exp_counts = torch.sum(mask, dim=0).detach().to(logits.device)

# Compute l_aux
me = torch.mean(gates, dim=0)
# HPU Enable Begin
ce = torch.mean(mask.float(), dim=0, dtype=torch.float)
# HPU Enable End
ce = torch.mean(mask.float(), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts / k

if drop_tokens:
# Calculate configured capacity and remove locations outside capacity from mask
capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
# update mask and locations by capacity
# mask *= torch.lt(locations, capacity)

if drop_policy == 'probs':
capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
Expand All @@ -421,7 +416,7 @@ def topkgating(
elif drop_policy == "position":
locations = torch.cumsum(mask, dim=0) - 1
mask *= torch.lt(locations, capacity)
else:
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")

else:
Expand All @@ -436,7 +431,6 @@ def topkgating(
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity


# normalize gates
gates_masked = gates * mask
gates_s = torch.sum(gates_masked, dim=-1, keepdim=True)
Expand Down Expand Up @@ -485,11 +479,6 @@ def __init__(self,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()

# Only top-1 and top-2 are supported at the moment.
#if k != 1 and k != 2:
# raise ValueError('Only top-1 and top-2 gatings are supported.')
# HPU Enable Begin
# self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.ep_group = ep_group
self.k = k
Expand Down
76 changes: 38 additions & 38 deletions tests/unit/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
import deepspeed.comm as dist
from deepspeed import get_accelerator
from deepspeed.moe.sharded_moe import top1gating,topkgating
from deepspeed.moe.sharded_moe import top1gating, topkgating
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
from deepspeed.utils.torch import required_torch_version

Expand Down Expand Up @@ -174,6 +174,7 @@ def test(self, ep_size, use_residual):
model.backward(loss)
model.step()


class TestTopk(DistributedTest):
world_size = 2

Expand All @@ -192,51 +193,50 @@ def test(self):
drop_tokens=False,
use_rts=True,
use_tutel=False)



class TestTopkGate(DistributedTest):

def test(self):
def check_equal(logits,cap,sparse_truth,res):
m,n=logits.shape
dispatch_mask_truth=torch.zeros(m,n,cap)
i,j,k=sparse_truth.t()
dispatch_mask_truth[i,j,k]=1
assert(torch.equal(dispatch_mask_truth, res))

def check_equal(logits, cap, sparse_truth, res):
m, n = logits.shape
dispatch_mask_truth = torch.zeros(m, n, cap)
i, j, k = sparse_truth.t()
dispatch_mask_truth[i, j, k] = 1
assert (torch.equal(dispatch_mask_truth, res))

#s=4 e=4 topk=2 cap=2(s*topk/e)
logits=torch.tensor([
[0.11,0.2,0.1,0.3],
[0.3,0.4,0.11,0.1],
[0.11,0.1,0.6,0.5],
[0.1,0.11,0.7,0.8]])
logits*=dist.get_rank()+1
probs_dispatch_res=topkgating(logits,2,1,min_capacity=1,drop_policy='probs')[2]
logits = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5],
[0.1, 0.11, 0.7, 0.8]])
logits *= dist.get_rank() + 1
probs_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='probs')[2]
probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]])
check_equal(logits,2,probs_sec_sparse,probs_dispatch_res)


position_sec_sparse =torch.tensor([[0,1,0],[0,3,0],[1,0,0],[1,1,1],[2,2,0],[2,3,1],[3,2,1]])
position_dispatch_res=topkgating(logits,2,1,min_capacity=1,drop_policy='position')[2]
check_equal(logits,2,position_sec_sparse,position_dispatch_res)

check_equal(logits, 2, probs_sec_sparse, probs_dispatch_res)

position_sec_sparse = torch.tensor([[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1],
[3, 2, 1]])
position_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='position')[2]
check_equal(logits, 2, position_sec_sparse, position_dispatch_res)

#s=4 e=6 topk=3 cap=2(s*topk/e)
logits2=torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
[0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
[0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
[0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]])
logits2*=dist.get_rank()+1
logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
[0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
[0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
[0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]])
logits2 *= dist.get_rank() + 1

#top3 full mask #prob_mask #postion_mask
#0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1
#0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1
#0 1 0 1 0 1 #0 0 0 1 0 0 #0 1 0 1 0 1
#0 1 1 1 0 0 #0 1 1 1 0 0 #0 1 1 1 0 0
#1 1 0 0 0 1 #1 1 0 0 0 1 #1 0 0 0 0 0
probs_dispatch_res=topkgating(logits2,3,1,min_capacity=1,drop_policy='probs')[2]
probs_sec_sparse = torch.tensor([[0,2,0],[0,4,0],[0,5,0],[1,3,0],[2,1,0],[2,2,1],[2,3,1],[3,0,0],[3,1,1],[3,5,1]])
check_equal(logits2,2, probs_sec_sparse,probs_dispatch_res)

position_sec_sparse =torch.tensor([[0,2,0],[0,4,0],[0,5,0],[1,1,0],[1,3,0],[1,5,1],[2,1,1],[2,2,1],[2,3,1],[3,0,0]])
position_dispatch_res=topkgating(logits2,3,1,min_capacity=1,drop_policy='position')[2]
check_equal(logits2,2,position_sec_sparse,position_dispatch_res)



probs_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='probs')[2]
probs_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 3, 0], [2, 1, 0], [2, 2, 1], [2, 3, 1],
[3, 0, 0], [3, 1, 1], [3, 5, 1]])
check_equal(logits2, 2, probs_sec_sparse, probs_dispatch_res)

position_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1],
[2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)