
[Kernel] AWQ Fused MoE #6422

Closed
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -72,6 +72,14 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
int64_t split_k_iters);

torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _topk_weights,
torch::Tensor _sorted_token_ids_ptr,
torch::Tensor _expert_ids_ptr,
torch::Tensor _num_tokens_post_padded,
bool mul_weights, int64_t split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int64_t split_k_iters,
424 changes: 417 additions & 7 deletions csrc/quantization/awq/gemm_kernels.cu

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -121,6 +121,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("awq_gemm", &awq_gemm);
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

// Quantized Grouped GEMM for AWQ.
ops.def("awq_fused_moe", &awq_fused_moe);
ops.impl("awq_fused_moe", torch::kCUDA, &awq_fused_moe);

// Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize);
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
96 changes: 95 additions & 1 deletion tests/kernels/test_moe.py
@@ -7,8 +7,11 @@
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import (fused_experts_awq, fused_moe,
fused_topk)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.mixtral import MixtralMoE


@@ -99,3 +102,94 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, score,
topk):
score = torch.softmax(score.float(), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)
(B, D) = a.shape
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
out = torch.zeros(B * topk_ids.shape[1],
w2.shape[2] * 8,
dtype=a.dtype,
device=a.device)
topk_ids = topk_ids.view(-1)
topk_weight = topk_weight.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
dw1 = ops.awq_dequantize(w1[i], w1_scale[i], w1_zero[i], 0, 0, 0)
dw2 = ops.awq_dequantize(w2[i], w2_scale[i], w2_zero[i], 0, 0, 0)
r1 = SiluAndMul()(torch.matmul(a[mask].half(), dw1))
out[mask] = torch.matmul(r1, dw2).to(out.dtype)
return (out.view(B, -1, w2.shape[2] * 8) *
topk_weight.view(B, -1, 1)).sum(dim=1).half()


@pytest.mark.parametrize("m", [1024, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2, 6])
def test_fused_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
):
# awq requires minimum capability 75
if torch.version.hip is not None:
return
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < 75:
return

RANGE = 1000000000
groupsize = 128
a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10
qw1 = torch.randint(-RANGE,
RANGE, (e, k, n * 2 // 8),
dtype=torch.int,
device='cuda')
qw2 = torch.randint(-RANGE,
RANGE, (e, n, k // 8),
dtype=torch.int,
device='cuda')

scale1 = torch.randn(
(e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50
scale2 = torch.randn(
(e, n // groupsize, k), dtype=torch.half, device='cuda') / 50

zero1 = torch.randint(-RANGE,
RANGE, (e, k // groupsize, (n * 2 // 32) * 4),
dtype=torch.int32,
device='cuda')
zero2 = torch.randint(-RANGE,
RANGE, (e, n // groupsize, (k // 32) * 4),
dtype=torch.int32,
device='cuda')
w1 = {"qweight": qw1, "scales": scale1, "qzeros": zero1}
w2 = {"qweight": qw2, "scales": scale2, "qzeros": zero2}

score = torch.randn((m, e), device='cuda', dtype=torch.half)

quant_config = AWQConfig(4, groupsize, False)
torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2,
score, topk)

topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
cuda_output = fused_experts_awq(hidden_states=a,
w1=w1["qweight"],
w2=w2["qweight"],
w1_scales=w1["scales"],
w2_scales=w2["scales"],
w1_qzeros=w1["qzeros"],
w2_qzeros=w2["qzeros"],
topk_weights=topk_weights,
topk_ids=topk_ids,
pack_factor=quant_config.pack_factor)
assert torch.allclose(cuda_output, torch_output, atol=1e-2, rtol=0)
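
For orientation, the test above fixes the AWQ tensor layout only implicitly through the shapes it constructs. A minimal sketch of that convention (an illustration, not part of this PR): 4-bit weights are packed eight per int32 along the output dimension, one scale and zero-point group covers groupsize input channels, and w1 fuses the gate and up projections (output width 2*n).

def awq_moe_shapes(e: int, k: int, n: int, groupsize: int = 128, bits: int = 4):
    # Expected tensor shapes for e experts with input dim k and intermediate
    # dim n, mirroring the tensors built in test_fused_moe_awq above.
    pack_factor = 32 // bits  # 8 int4 values per int32
    return {
        "w1_qweight": (e, k, 2 * n // pack_factor),
        "w1_scales": (e, k // groupsize, 2 * n),
        "w1_qzeros": (e, k // groupsize, 2 * n // pack_factor),
        "w2_qweight": (e, n, k // pack_factor),
        "w2_scales": (e, n // groupsize, k),
        "w2_qzeros": (e, n // groupsize, k // pack_factor),
    }

For example, awq_moe_shapes(e=8, k=1024, n=2048) reproduces the shapes used for qw1, scale1, zero1, qw2, scale2 and zero2 in the test (the qzeros widths coincide with the test's (dim // 32) * 4 expressions because the tested sizes are multiples of 32).
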
11 changes: 11 additions & 0 deletions vllm/_custom_ops.py
@@ -192,6 +192,17 @@ def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


def awq_fused_moe(input: torch.Tensor, qweight: torch.Tensor,
scales: torch.Tensor, qzeros: torch.Tensor,
topk_weights: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_weights: bool, pack_factor: int) -> torch.Tensor:
return torch.ops._C.awq_fused_moe(input, qweight, scales, qzeros,
topk_weights, sorted_token_ids,
expert_ids, num_tokens_post_padded,
mul_weights, pack_factor)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -8,15 +8,17 @@
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe_awq import (
fused_experts_awq)

__all__ += [
"fused_experts_awq",
"fused_moe",
"fused_topk",
"fused_experts",
"fused_topk",
"get_config_file_name",
"grouped_topk",
]
87 changes: 87 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_awq.py
@@ -0,0 +1,87 @@
"""Fused MoE utilities for AWQ."""
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, moe_align_block_size)

logger = init_logger(__name__)

NAIVE_THRESHOLD = 1024
Member:
This seems a bit high, and it is worth commenting how it was calibrated (what model, benchmark, and GPU were used).

Contributor:
@robertgshaw2-neuralmagic do we know why this is 1024 specifically?

(An illustrative calibration sketch follows the full file listing below.)


def fused_experts_awq(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scales: torch.Tensor,
w2_scales: torch.Tensor,
w1_qzeros: torch.Tensor,
w2_qzeros: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
pack_factor: int,
) -> torch.Tensor:
"""
This function computes an AWQ fused_expert.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scales (torch.Tensor): scale to be used for w1.
- w2_scales (torch.Tensor): scale to be used for w2.
- w1_qzeros (torch.Tensor): zero point to be used for w1.
- w2_qzeros (torch.Tensor): zero point to be used for w2.
- pack_factor (int): Weight packing factor (int4 in int32 == 8)

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""

# If large seq_len prefill, dequantize and use the fp16 MoE kernel.
do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD
if do_naive_dequant:
# NOTE: not contiguous because of the permutation operation
dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0,
0).permute(0, 2, 1).contiguous()
dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0,
0).permute(0, 2, 1).contiguous()

return fused_experts(hidden_states, dequant_w1, dequant_w2,
topk_weights, topk_ids)

(sorted_token_ids, expert_ids,
num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0])

x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:])

gate_up = ops.awq_fused_moe(input=x,
qweight=w1,
scales=w1_scales,
qzeros=w1_qzeros,
topk_weights=topk_weights,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
mul_weights=False,
pack_factor=pack_factor)

out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
dtype=hidden_states.dtype,
device=hidden_states.device)
ops.silu_and_mul(out, gate_up)

out = ops.awq_fused_moe(input=out,
qweight=w2,
scales=w2_scales,
qzeros=w2_qzeros,
topk_weights=topk_weights,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
mul_weights=True,
pack_factor=pack_factor)

return torch.sum(out, dim=1)
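
As the review thread above asks, the NAIVE_THRESHOLD value is worth documenting. A rough calibration sketch (an illustration, not part of this PR): time fused_experts_awq at the same token count with the threshold forced to each extreme, so the quantized awq_fused_moe path and the dequantize-plus-fp16 fused_experts path are compared head to head. The expert dimensions below (k=4096, n=14336, topk=2) are assumed Mixtral-8x7B-like placeholders; the crossover point will depend on the GPU and model shape.

import torch

import vllm.model_executor.layers.fused_moe.fused_moe_awq as fused_moe_awq
from vllm.model_executor.layers.fused_moe import fused_experts_awq, fused_topk

RANGE = 1_000_000_000


def time_path(m, e=8, k=4096, n=14336, groupsize=128, force_naive=False):
    pack = 8  # int4 packed into int32
    a = torch.randn((m, k), device="cuda", dtype=torch.half) / 10
    qw1 = torch.randint(-RANGE, RANGE, (e, k, 2 * n // pack),
                        dtype=torch.int32, device="cuda")
    qw2 = torch.randint(-RANGE, RANGE, (e, n, k // pack),
                        dtype=torch.int32, device="cuda")
    s1 = torch.randn((e, k // groupsize, 2 * n),
                     dtype=torch.half, device="cuda") / 50
    s2 = torch.randn((e, n // groupsize, k),
                     dtype=torch.half, device="cuda") / 50
    z1 = torch.randint(-RANGE, RANGE, (e, k // groupsize, 2 * n // pack),
                       dtype=torch.int32, device="cuda")
    z2 = torch.randint(-RANGE, RANGE, (e, n // groupsize, k // pack),
                       dtype=torch.int32, device="cuda")
    score = torch.randn((m, e), device="cuda", dtype=torch.half)
    topk_weights, topk_ids = fused_topk(a, score, 2, renormalize=False)

    # Force one code path or the other by moving the threshold.
    fused_moe_awq.NAIVE_THRESHOLD = 0 if force_naive else 1 << 30

    def run():
        return fused_experts_awq(a, qw1, qw2, s1, s2, z1, z2,
                                 topk_weights, topk_ids, pack_factor=pack)

    run()  # warmup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(10):
        run()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 10


for m in (64, 256, 512, 1024, 2048, 4096):
    quant_ms = time_path(m, force_naive=False)
    naive_ms = time_path(m, force_naive=True)
    print(f"tokens={m:5d}  quantized={quant_ms:.2f} ms  naive={naive_ms:.2f} ms")

The token count at which the naive (dequantize) timing starts beating the quantized-kernel timing is the natural choice for NAIVE_THRESHOLD on that hardware and model shape.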