Open
Description
Hello, I found that the Phi-3.5-MoE model uses the sparsemixer() function to select the top-k experts and compute their weights, but I couldn't find the implementation of this function in the code. Could you give me some advice? Thanks!
def _sparsemixer_select_one(scores, candidate_scores, jitter_eps, training):
    """Pick one expert per row from ``candidate_scores`` and return its gate weight.

    Args:
        scores: The unmodified expert scores, experts along the last dimension.
            Used for the sparsity threshold and passed through to ``mp.apply``.
        candidate_scores: The scores the expert is chosen from: ``scores``
            itself for the first pick, or ``scores`` with the already-picked
            expert set to ``-inf`` for the second pick.
        jitter_eps: An expert is masked out when
            ``(best - score) / max(|score|, best) > 2 * jitter_eps``.
        training: If true, sample the expert and route the weight through
            ``mp.apply``; otherwise take the argmax and the plain softmax weight.

    Returns:
        ``(multiplier, selected)``, both with a trailing dimension of size 1.
    """
    with torch.no_grad():
        # compute mask for sparsity (index/boolean work only, no gradient needed)
        best_score, argmax_ind = candidate_scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=best_score)
        sparsity_mask = ((best_score - scores) / factor) > (2 * jitter_eps)

    # apply mask -- outside no_grad so the surviving gates stay differentiable
    masked_gates = candidate_scores.masked_fill(sparsity_mask, float('-inf'))
    if training:
        # gumbel sampling, more robust than the multinomial method
        noise = torch.empty_like(
            masked_gates, memory_format=torch.legacy_contiguous_format
        ).exponential_().log()
        selected = (masked_gates - noise).max(dim=-1)[1].unsqueeze(-1)
    else:
        selected = argmax_ind

    # compute scores for gradients
    masked_gates = torch.softmax(masked_gates, dim=-1)
    multiplier = masked_gates.gather(dim=-1, index=selected)
    if not training:
        return multiplier, selected

    # compute midpoint mask
    max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
    mask_for_one = torch.logical_or(
        selected == max_ind,
        # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
        torch.rand_like(max_scores) > 0.75,
    )
    # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
    mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
    # NOTE(review): `mp` is not defined in this snippet; it must be provided by
    # the surrounding module (presumably a custom autograd function -- confirm).
    multiplier = mp.apply(
        scores,
        multiplier,
        selected,
        masked_gates,
        mask_for_one,
    )
    return multiplier, selected


def sparsemixer(scores, top_k, jitter_eps, training):
    """Select two experts per row of ``scores`` and compute their gate weights.

    The first expert is chosen from ``scores``; that expert is then set to
    ``-inf`` and the second expert is chosen from what remains. Each weight is
    a softmax over the candidates that survive the ``jitter_eps`` sparsity
    mask of its own pass, so the two weights are not normalised against each
    other.

    Args:
        scores: Expert scores with experts along the last dimension.
        top_k: Number of experts to select; only ``2`` is supported.
        jitter_eps: Relative-gap tolerance of the sparsity mask (see
            ``_sparsemixer_select_one``).
        training: If true, experts are sampled and the weights go through
            ``mp.apply``; if false, selection is a deterministic argmax.

    Returns:
        ``(multiplier, selected_experts)``: the gate weights and the expert
        indices, each with a trailing dimension of size 2 (first pick, then
        second pick).

    Raises:
        AssertionError: If ``top_k`` is not 2.
    """
    assert top_k == 2, "sparsemixer only supports top_k == 2"

    ################ first expert ################
    multiplier, selected_experts = _sparsemixer_select_one(
        scores, scores, jitter_eps, training
    )

    ################ second expert ################
    # masked out first expert
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float('-inf'),
    )
    multiplier_top2, selected_experts_top2 = _sparsemixer_select_one(
        scores, masked_scores, jitter_eps, training
    )

    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
    return (
        multiplier,
        selected_experts,
    )