moe quant with dedicated kernels [wip] #2325

Open · wants to merge 1 commit into main
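This PR wires MoE quantization into dedicated fused kernels: a new fp8_dq_moe_op in torchao.prototype.moe_quant.kernels runs float8 dynamic-activation, float8-weight (per-row) experts through a scaled grouped GEMM, with an optional FBGEMM-backed path selected via use_fbgemm_kernel=True. The sketch below mirrors how the new tests exercise the op; the import locations of MOEFeedForwardAOQuantizable, MoEQuantConfig, and cond_ffn_filter are assumed from the existing moe_quant prototype and are not shown in this diff.

import torch
from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op
# assumed import paths for the prototype helpers used by the tests
from torchao.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable
from torchao.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_

hidden_dim, expert_dim, num_experts, top_k = 512, 256, 8, 2
model = MOEFeedForwardAOQuantizable(hidden_dim, expert_dim, num_experts, top_k, empty_init=False)
model = model.to(torch.bfloat16).to("cuda")
quantize_(model, MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())), cond_ffn_filter)

x = torch.randn(8, hidden_dim, dtype=torch.bfloat16, device="cuda")  # [T, D]
scores = torch.nn.functional.softmax(model.router(x), dim=-1)        # [T, E]
scores, expert_indices = torch.topk(scores, top_k, dim=-1)           # [T, A], [T, A]
scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype)               # [T, A]

# PyTorch scaled grouped GEMM path; pass use_fbgemm_kernel=True (and build the
# model with use_fbgemm_kernel=True) for the FBGEMM-backed variant.
out = fp8_dq_moe_op(x, model.experts.w1, model.experts.w2, model.experts.w3,
                    expert_indices, scores)

As in the tests, this assumes a CUDA device with compute capability 9.0 or newer.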
114 changes: 113 additions & 1 deletion test/quantization/test_moe_quant.py
@@ -25,20 +25,26 @@
Int8WeightOnlyConfig,
LinearActivationQuantizedTensor,
quantize_,
PerRow,
PerTensor,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_90,
)
from torchao.quantization.utils import compute_error

if torch.version.hip is not None:
pytest.skip(
"ROCm support for MoE quantization is under development",
allow_module_level=True,
)
from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op
from torchao.quantization.utils import _fbgemm_available

torch.manual_seed(0)

class TestMoEQuantCompile(unittest.TestCase):
DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k
@@ -68,7 +74,6 @@ def _test_impl_moe_quant(
.to(device)
)
input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

out = model(input)

quantize_(model, config, cond_ffn_filter)
@@ -363,6 +368,113 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph):
fullgraph=fullgraph,
)

class TestFusedMoEQuant(unittest.TestCase):
DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k

@parameterized.expand(
[
("multiple_tokens", 8),
]
)
def test_pytorch_scaled_grouped_gemm(self, name, num_tokens):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
if not is_sm_at_least_90():
self.skipTest("Requires CUDA capability >= 9.0")

device = "cuda"
dtype = torch.bfloat16

config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

model_params = self.DEFAULT_PARAMS

input_shape = (num_tokens, model_params[0])
input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

model = (
MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
)
model = model.to(dtype).to(device)

out_orig = model(input)

quantize_(model, config, cond_ffn_filter)

w1 = model.experts.w1
w2 = model.experts.w2
w3 = model.experts.w3

router = model.router
top_k = model.top_k

# preprocess
scores = router(input) # [T, E]
scores = torch.nn.functional.softmax(scores, dim=-1)
scores, expert_indices = torch.topk(
scores, top_k, dim=-1
) # [T, A], [T, A]
scores /= scores.sum(dim=-1, keepdim=True).to(input.dtype) # [T, A]

out = fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores)
out2 = model(input)

self.assertTrue(compute_error(out_orig, out) > 20)
self.assertTrue(compute_error(out_orig, out2) > 20)


@parameterized.expand(
[
("multiple_tokens", 8),
]
)
def test_fbgemm_scaled_grouped_gemm(self, name, num_tokens):
if not _fbgemm_available:
self.skipTest("Need FBGEMM available")
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
if not is_sm_at_least_90():
self.skipTest("Requires CUDA capability >= 9.0")

device = "cuda"
dtype = torch.bfloat16

config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

model_params = self.DEFAULT_PARAMS

input_shape = (num_tokens, model_params[0])
input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

model = (
MOEFeedForwardAOQuantizable(*model_params, empty_init=False, use_fbgemm_kernel=True)
)
model = model.to(dtype).to(device)

out_orig = model(input)

quantize_(model, config, cond_ffn_filter)

w1 = model.experts.w1
w2 = model.experts.w2
w3 = model.experts.w3

router = model.router
top_k = model.top_k

# preprocess
scores = router(input) # [T, E]
scores = torch.nn.functional.softmax(scores, dim=-1)
scores, expert_indices = torch.topk(
scores, top_k, dim=-1
) # [T, A], [T, A]
scores /= scores.sum(dim=-1, keepdim=True).to(input.dtype) # [T, A]

out = fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores, use_fbgemm_kernel=True)
out2 = model(input)

self.assertTrue(compute_error(out_orig, out) > 20)
self.assertTrue(compute_error(out_orig, out2) > 20)

if __name__ == "__main__":
unittest.main()
52 changes: 14 additions & 38 deletions torchao/_models/mixtral-moe/model.py
@@ -7,12 +7,14 @@
from typing import Optional

import torch
import torchao
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from torchao.prototype.moe_quant.utils import FakeExtraDimTensor

from torchao.quantization.utils import _torchtitan_available
from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op

def find_multiple(n: int, k: int) -> int:
if n % k == 0:
@@ -34,6 +36,7 @@ class ModelArgs:
norm_eps: float = 1e-5
num_experts: int = 8
num_activated_experts: int = 2
use_fbgemm_kernel: bool = False

def __post_init__(self):
if self.n_local_heads == -1:
@@ -225,43 +228,6 @@ def forward(
y = self.wo(y)
return y


# class ConditionalFeedForward(nn.Module):
# def __init__(self, config):
# super().__init__()
# self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
# self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
# self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))

# def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
# w1_weights = self.w1[expert_indices] # [T, A, D, D]
# w3_weights = self.w3[expert_indices] # [T, A, D, D]
# w2_weights = self.w2[expert_indices] # [T, A, D, D]
# x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
# x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
# expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
# return expert_outs


# class MOEFeedForward(nn.Module):
# def __init__(self, config) -> None:
# super().__init__()
# self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
# self.cond_ffn = ConditionalFeedForward(config)
# self.dim = config.dim
# self.num_activated_experts = config.num_activated_experts
# def forward(self, x: Tensor) -> Tensor:
# x = x.view(-1, self.dim)
# # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
# # x: [T, D]
# scores = self.gate(x) # [T, E]
# expert_weights = F.softmax(scores, dim=-1)
# expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
# expert_outs = self.cond_ffn(x, expert_indices)
# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)


class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
@@ -347,7 +313,9 @@ def __init__(self, config):
torch.empty(config.num_experts, config.intermediate_size, config.dim)
) # E, I, D
self.num_experts = config.num_experts
self.use_fbgemm_kernel = config.use_fbgemm_kernel

# TODO move this into kernels, single token decomposed kernel, multi token...etc
def forward(
self,
x: Tensor, # T, D
@@ -382,6 +350,14 @@ def forward(
.unsqueeze(-1)
)
return final_out
# fp8 dq moe
elif (
isinstance(self.w1, torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor) and
isinstance(self.w1.original_weight_tensor._layout, torchao.dtypes.floatx.float8_layout.Float8Layout)
):

final_out = fp8_dq_moe_op(x, self.w1, self.w2, self.w3, expert_indices, expert_weights, use_fbgemm_kernel=self.use_fbgemm_kernel)
return final_out
else:
expert_list = [x for x in range(self.num_experts)]

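For reference, a bf16 sketch of the computation fp8_dq_moe_op is expected to approximate, adapted from the commented-out ConditionalFeedForward that this diff deletes (weight shapes follow the [E, I, D] and [E, D, I] parameters above; this is an illustration of the math, not the kernel implementation):

import torch
import torch.nn.functional as F

def moe_reference(x, w1, w2, w3, expert_indices, expert_weights):
    # x: [T, D], w1/w3: [E, I, D], w2: [E, D, I]
    # expert_indices, expert_weights: [T, A] top-k routing from the gate
    w1_w = w1[expert_indices]                                       # [T, A, I, D]
    w3_w = w3[expert_indices]                                       # [T, A, I, D]
    w2_w = w2[expert_indices]                                       # [T, A, D, I]
    x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1_w))              # [T, A, I]
    x3 = torch.einsum("ti,taoi->tao", x, w3_w)                      # [T, A, I]
    expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2_w)      # [T, A, D]
    return torch.einsum("tai,ta->ti", expert_outs, expert_weights)  # [T, D]

The new tests accept the fused op when compute_error (an SQNR measure in dB) against the unquantized bf16 model exceeds 20.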