Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable/disable moe token dropping. #1492

Merged
merged 3 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add a flag to enable/disable token dropping in moe/top-1 gating.
  • Loading branch information
awan-10 committed Oct 27, 2021
commit 2d1ce86136dc6256632fe306cc334c8f1463e8f0
6 changes: 4 additions & 2 deletions deepspeed/moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def __init__(self,
capacity_factor=1.,
eval_capacity_factor=1.,
min_capacity=4,
noisy_gate_policy: typing.Optional[str] = None):
noisy_gate_policy: typing.Optional[str] = None,
drop_tokens: bool = True):
"""Initialize an MoE layer.

Arguments:
Expand Down Expand Up @@ -66,7 +67,8 @@ def __init__(self,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy),
noisy_gate_policy,
drop_tokens),
experts,
num_local_experts,
group=groups.get_expert_parallel_group())
Expand Down
17 changes: 12 additions & 5 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ def top1gating(logits: torch.Tensor,
capacity_factor: float,
min_capacity: int,
used_token: torch.Tensor = None,
noisy_gate_policy: Optional[str] = None) -> Tuple[Tensor,
Tensor,
Tensor]:
noisy_gate_policy: Optional[str] = None
drop_tokens: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
Expand Down Expand Up @@ -167,6 +166,12 @@ def top1gating(logits: torch.Tensor,
# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
capacity = new_capacity

# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
Expand Down Expand Up @@ -306,7 +311,8 @@ def __init__(self,
capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_capacity: int = 4,
noisy_gate_policy: Optional[str] = None) -> None:
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True) -> None:
super().__init__()

# Only top-1 and top-2 are supported at the moment.
Expand Down Expand Up @@ -347,7 +353,8 @@ def forward(
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity,
used_token,
self.noisy_gate_policy if self.training else None)
self.noisy_gate_policy if self.training else None,
drop_tokens)

else:
gate_output = top2gating(
Expand Down