MoE for mixtral 8x7b (#2535)
* MoE for mixtral 8x7b
* removing bnb_sparse for now
vince62s authored Dec 19, 2023
1 parent 2509d93 commit 05cde4d
Showing 4 changed files with 128 additions and 13 deletions.
57 changes: 45 additions & 12 deletions onmt/decoders/transformer.py
@@ -9,6 +9,7 @@
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.modules.moe import MoE
from onmt.utils.misc import sequence_mask

try:
@@ -43,6 +44,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
"""
Args:
@@ -109,18 +112,34 @@ def __init__(
d_model, dropout=attention_dropout, aan_useffn=aan_useffn
)

self.feed_forward = PositionwiseFeedForward(
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
if num_experts > 0:
self.feed_forward = MoE(
num_experts,
num_experts_per_tok,
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
else:
self.feed_forward = PositionwiseFeedForward(
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
self.parallel_residual = parallel_residual
self.shared_layer_norm = shared_layer_norm
if layer_norm == "standard":
@@ -261,6 +280,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
"""
Args:
@@ -290,6 +311,8 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
self.context_attn = MultiHeadedAttention(
heads,
@@ -450,6 +473,8 @@ def from_opt(cls, opt, embeddings):
else 1,
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
num_experts=opt.num_experts,
num_experts_per_tok=opt.num_experts_per_tok,
)

def init_state(self, src, enc_out, enc_final_hs):
@@ -569,6 +594,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
super(TransformerDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -600,6 +627,8 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
for i in range(num_layers)
]
@@ -836,6 +865,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
super(TransformerLMDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -866,6 +897,8 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
for i in range(num_layers)
]
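A condensed sketch of what the transformer.py changes amount to (build_feed_forward is a hypothetical helper, not part of the commit; argument names mirror the diff): when num_experts > 0, the decoder layer's feed-forward sublayer is built as an MoE block; otherwise the dense PositionwiseFeedForward is constructed exactly as before, with the same trailing arguments.

from onmt.modules.moe import MoE
from onmt.modules.position_ffn import PositionwiseFeedForward


def build_feed_forward(num_experts, num_experts_per_tok, d_model, d_ff, dropout,
                       pos_ffn_activation_fn, add_ffnbias, parallel_residual,
                       layer_norm, norm_eps, use_ckpting=[], parallel_gpu=1):
    # Hypothetical helper mirroring the branch added to the decoder layer __init__.
    if num_experts > 0:
        # Mixtral-style sparse feed-forward: num_experts expert FFNs, top-k routed.
        return MoE(num_experts, num_experts_per_tok, d_model, d_ff, dropout,
                   pos_ffn_activation_fn, add_ffnbias, parallel_residual,
                   layer_norm, norm_eps,
                   use_ckpting=use_ckpting, parallel_gpu=parallel_gpu)
    # Dense feed-forward, unchanged from before this commit.
    return PositionwiseFeedForward(d_model, d_ff, dropout, pos_ffn_activation_fn,
                                   add_ffnbias, parallel_residual, layer_norm, norm_eps,
                                   use_ckpting=use_ckpting, parallel_gpu=parallel_gpu)
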
7 changes: 6 additions & 1 deletion onmt/modules/bnb_linear.py
@@ -7,7 +7,12 @@
try:
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from bitsandbytes import MatmulLtState
from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
from bitsandbytes.nn import (
Linear4bit,
Linear8bitLt,
Params4bit,
Int8Params,
)
except ImportError:
raise ImportError("Install bitsandbytes to use 4/8bit compression")

63 changes: 63 additions & 0 deletions onmt/modules/moe.py
@@ -0,0 +1,63 @@
"""MoE mixture of experts"."""
import torch
import torch.nn as nn
from onmt.modules.position_ffn import PositionwiseFeedForward


class MoE(nn.Module):
def __init__(
self,
num_experts,
num_experts_per_tok,
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=[],
parallel_gpu=1,
):
super().__init__()
self.experts = nn.ModuleList(
[
PositionwiseFeedForward(
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
for i in range(num_experts)
]
)
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.num_experts_per_tok = num_experts_per_tok

def forward(self, x):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])

scores = self.gate(x)
expert_weights, expert_indices = torch.topk(
scores, self.num_experts_per_tok, dim=-1
)
expert_weights = expert_weights.softmax(dim=-1)
flat_expert_indices = expert_indices.view(-1)

x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
y = torch.empty_like(x)
for i, expert in enumerate(self.experts):
if torch.any(flat_expert_indices == i):
y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
dim=1
)
return y.view(*orig_shape)
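
The forward pass above implements standard top-k token routing. Below is a toy restatement of that routing in plain torch (not the commit's code) with nn.Linear modules standing in for the expert feed-forwards, so the intermediate shapes are easy to follow; the 8-expert / top-2 setting matches Mixtral 8x7b, and the tensor sizes are arbitrary.

import torch
import torch.nn as nn

d_model, num_experts, top_k = 16, 8, 2           # Mixtral 8x7b routes top-2 of 8 experts
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
gate = nn.Linear(d_model, num_experts, bias=False)

x = torch.randn(4, 5, d_model)                   # (batch, seq_len, d_model)
orig_shape = x.shape
x = x.view(-1, d_model)                          # flatten tokens: (batch*seq_len, d_model)

scores = gate(x)                                 # router logits, (tokens, num_experts)
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
expert_weights = expert_weights.softmax(dim=-1)  # normalize over the selected experts only
flat_expert_indices = expert_indices.view(-1)

x = x.repeat_interleave(top_k, dim=0)            # one row per (token, selected expert) pair
y = torch.empty_like(x)
for i, expert in enumerate(experts):
    mask = flat_expert_indices == i
    if mask.any():
        y[mask] = expert(x[mask])                # each expert only processes its routed tokens
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
print(y.view(*orig_shape).shape)                 # torch.Size([4, 5, 16])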
14 changes: 14 additions & 0 deletions onmt/opts.py
@@ -901,6 +901,20 @@ def model_opts(parser):
default=2048,
help="Size of hidden transformer feed-forward",
)
group.add(
"--num_experts",
"-num_experts",
type=int,
default=0,
help="Number of experts",
)
group.add(
"--num_experts_per_tok",
"-num_experts_per_tok",
type=int,
default=2,
help="Number of experts per token",
)
group.add(
"--aan_useffn",
"-aan_useffn",
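
For reference, a stand-in sketch of the two new options using plain argparse (OpenNMT-py's real parser uses its own option groups, so this only shows the shape and defaults): num_experts defaults to 0, which keeps the dense feed-forward, and a Mixtral 8x7b style configuration would set 8 experts with 2 experts per token.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_experts", type=int, default=0,
                    help="Number of experts (0 keeps the dense feed-forward)")
parser.add_argument("--num_experts_per_tok", type=int, default=2,
                    help="Number of experts routed per token")

opt = parser.parse_args(["--num_experts", "8", "--num_experts_per_tok", "2"])
print(opt.num_experts, opt.num_experts_per_tok)  # 8 2 -> decoder layers build MoE feed-forwards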