[Kernel] AWQ Fused MoE #6415

Closed
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -67,6 +67,14 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
int64_t split_k_iters);

torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _topk_weights,
torch::Tensor _sorted_token_ids_ptr,
torch::Tensor _expert_ids_ptr,
torch::Tensor _num_tokens_post_padded,
bool mul_weights, int64_t split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int64_t split_k_iters,
421 changes: 414 additions & 7 deletions csrc/quantization/awq/gemm_kernels.cu

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -117,6 +117,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("awq_gemm", &awq_gemm);
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

// Quantized Grouped GEMM for AWQ.
ops.def("awq_fused_moe", &awq_fused_moe);
ops.impl("awq_fused_moe", torch::kCUDA, &awq_fused_moe);

// Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize);
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
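A minimal sanity check, not part of the PR, that the new binding resolves from Python once the extension is built (this assumes the compiled extension module is `vllm._C`):

```python
import torch
import vllm._C  # noqa: F401  # assumed module name; loads the extension that registers the ops

# After registration, the grouped GEMM op should resolve under torch.ops._C.
assert hasattr(torch.ops._C, "awq_fused_moe")
```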
11 changes: 11 additions & 0 deletions vllm/_custom_ops.py
@@ -177,6 +177,17 @@ def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


def awq_fused_moe(input: torch.Tensor, qweight: torch.Tensor,
scales: torch.Tensor, qzeros: torch.Tensor,
topk_weights: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_weights: bool, pack_factor: int) -> torch.Tensor:
return torch.ops._C.awq_fused_moe(input, qweight, scales, qzeros,
topk_weights, sorted_token_ids,
expert_ids, num_tokens_post_padded,
mul_weights, pack_factor)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
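For reference, a sketch of how the new wrapper fits together with the existing MoE helpers. The sizes are illustrative assumptions (8 experts, hidden size 4096, intermediate size 14336, 4-bit AWQ so pack_factor = 8, group_size 128); only `ops.awq_fused_moe`, `fused_topk`, and `moe_align_block_size` come from this PR and the existing codebase.

```python
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                            moe_align_block_size)

# Illustrative sizes: E experts, hidden H, intermediate I, int4 packing.
E, H, I, top_k, pack, group = 8, 4096, 14336, 2, 8, 128

x = torch.randn(16, H, dtype=torch.float16, device="cuda")      # 16 tokens
gating = torch.randn(16, E, dtype=torch.float16, device="cuda")  # router logits

# AWQ-packed gate/up projection, laid out [E, input_dim, output_dim // pack].
w13_qweight = torch.zeros(E, H, 2 * I // pack, dtype=torch.int32, device="cuda")
w13_scales = torch.ones(E, H // group, 2 * I, dtype=torch.float16, device="cuda")
w13_qzeros = torch.zeros(E, H // group, 2 * I // pack, dtype=torch.int32, device="cuda")

topk_weights, topk_ids = fused_topk(x, gating, top_k, renormalize=True)
sorted_ids, expert_ids, num_post_pad = moe_align_block_size(topk_ids, 16, E)

# The kernel consumes hidden states with an explicit (token, 1, hidden) view.
gate_up = ops.awq_fused_moe(x.view(16, 1, H), w13_qweight, w13_scales,
                            w13_qzeros, topk_weights, sorted_ids, expert_ids,
                            num_post_pad, False, pack)
```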
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,10 +1,12 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe_awq import fused_moe_awq
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)

__all__ = [
"fused_moe",
"fused_moe_awq",
"fused_topk",
"fused_experts",
"get_config_file_name",
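The new export makes the AWQ path importable alongside the existing entry points; a trivial sketch:

```python
# Both the fp16 and the AWQ fused-MoE entry points are now exposed here.
from vllm.model_executor.layers.fused_moe import fused_moe, fused_moe_awq
```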
79 changes: 79 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_awq.py
@@ -0,0 +1,79 @@
"""Fused MoE utilities for AWQ."""
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger

from .fused_moe import fused_moe, fused_topk, moe_align_block_size

logger = init_logger(__name__)


def fused_moe_awq(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
pack_factor: int,
w1_scales: torch.Tensor,
w2_scales: torch.Tensor,
w1_qzeros: torch.Tensor,
w2_qzeros: torch.Tensor,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and a top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- pack_factor (int): Weight packing factor (8 when int4 values are packed into int32).
- w1_scales (torch.Tensor): scale to be used for w1.
- w2_scales (torch.Tensor): scale to be used for w2.
- w1_qzeros (torch.Tensor): zero point to be used for w1.
- w2_qzeros (torch.Tensor): zero point to be used for w2.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""

# If large seq_len prefill, dequantize and use the fp16 MoE kernel.
do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024
[Review comment, Member] I don't think numel of a shape works here, you should use the product.

if do_naive_dequant:
# TODO: why is this not contiguous already?
[Review comment, Author] @casper-hansen Any idea why these are not contiguous by default?

[Review reply, Contributor] I am not sure. The dequantization kernels were originally implemented in FasterTransformer, then adapted for dequantization for AWQ. I can only assume it would cause problems when running the GEMM kernel, which uses shared memory.

dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0,
0).permute(0, 2, 1).contiguous()
dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0,
0).permute(0, 2, 1).contiguous()

return fused_moe(hidden_states, dequant_w1, dequant_w2, gating_output,
topk, renormalize)

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
(sorted_token_ids, expert_ids,
num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0])

x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:])

gate_up = ops.awq_fused_moe(x, w1, w1_scales, w1_qzeros, topk_weights,
sorted_token_ids, expert_ids,
num_tokens_post_padded, False, pack_factor)

out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
dtype=hidden_states.dtype,
device=hidden_states.device)
ops.silu_and_mul(out, gate_up)

out = ops.awq_fused_moe(out, w2, w2_scales, w2_qzeros, topk_weights,
sorted_token_ids, expert_ids,
num_tokens_post_padded, True, pack_factor)

return torch.sum(out, dim=1)
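On the review comment about `hidden_states.shape[:-1].numel()`: a minimal sketch of the suggested alternative, computing the token count as an explicit product over the leading dimensions. The helper name is hypothetical; the 1024 threshold is the one used in this file.

```python
import math

import torch


def _use_naive_dequant(hidden_states: torch.Tensor, threshold: int = 1024) -> bool:
    # Explicit product over the leading dims, per the review suggestion.
    return math.prod(hidden_states.shape[:-1]) >= threshold
```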
83 changes: 58 additions & 25 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -135,14 +135,14 @@ def __init__(
params_dtype=params_dtype,
weight_loader=self.weight_loader)

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int):
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int) -> None:
param_data = param.data

# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
if "input_scale" in weight_name or "w2_weight_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
@@ -155,28 +155,61 @@ def weight_loader(self, param: torch.nn.Parameter,
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
assert shard_id == 0 or shard_id == 2
shard_idx = 0 if shard_id == 0 else 1
param_data[expert_id][shard_idx] = loaded_weight

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int) -> None:
if shard_id not in [0, 1, 2]:
raise ValueError(f"Shard id must be in [0, 1, 2] but got {shard_id}")

# Special case for fp8 scales.
if getattr(param, "is_fp8_scale", False):
self._load_fp8_scale(param.data, loaded_weight, weight_name,
shard_id, expert_id)
return

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
is_gate_proj = (shard_id == 0)
is_down_proj = (shard_id == 1)
is_up_proj = (shard_id == 2)

# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
is_transposed = getattr(param, "is_transposed", False)
input_dim = 0 if is_transposed else 1
output_dim = 1 if is_transposed else 0

# Index the loaded weight for tp sharding.
# * down_proj: "RowParallel" so tp sharding on input_dim
if (is_down_proj):
shard_dim = input_dim
shard_size = expert_data.shape[shard_dim]
# * gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif (is_gate_proj or is_up_proj):
shard_dim = output_dim
shard_size = expert_data.shape[output_dim] // 2
offset = shard_size * tp_rank
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)

# Narrow parameter and load.
# w1, gate_proj: Load into first shard of w13.
if is_gate_proj:
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
expert_data.copy_(loaded_weight)
# w3, up_proj: Load into second shard of w13.
elif is_up_proj:
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
# w2, down_proj: Load into only shard of w2.
elif is_down_proj:
expert_data.copy_(loaded_weight)
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# w1, gate_proj case: Load into first shard of w13.
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
raise ValueError


def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
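To make the narrow-based tensor-parallel sharding above concrete, a small worked sketch. The sizes (hidden 4096, intermediate 14336, tp_size 2) and the helper name are illustrative assumptions.

```python
import torch


def shard_for_rank(loaded_weight: torch.Tensor, shard_dim: int,
                   shard_size: int, tp_rank: int) -> torch.Tensor:
    # Mirrors the narrow() call in weight_loader: keep only this rank's slice.
    return loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)


# down_proj (w2) is "RowParallel", so it is sharded on its input dim:
# the full [4096, 14336] weight becomes a [4096, 7168] slice per rank.
full_w2 = torch.empty(4096, 14336)
rank0 = shard_for_rank(full_w2, shard_dim=1, shard_size=7168, tp_rank=0)
rank1 = shard_for_rank(full_w2, shard_dim=1, shard_size=7168, tp_rank=1)
assert rank0.shape == rank1.shape == (4096, 7168)
```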
128 changes: 126 additions & 2 deletions vllm/model_executor/layers/quantization/awq.py
@@ -4,9 +4,11 @@
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
fused_moe_awq)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs


@@ -64,9 +66,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
return cls(weight_bits, group_size, zero_point)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
@@ -174,3 +178,123 @@ def apply(self,
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)


class AWQMoEMethod(FusedMoEMethodBase):

def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):

# WEIGHTS
w13_qweight = Parameter(torch.empty(num_experts,
hidden_size,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(
w13_qweight, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
**extra_weight_attrs
})

w2_qweight = Parameter(torch.empty(num_experts,
intermediate_size,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(
w2_qweight, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
**extra_weight_attrs
})

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts,
hidden_size //
self.quant_config.group_size,
intermediate_size * 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, {
"is_transposed": True,
**extra_weight_attrs
})

w2_scales = Parameter(torch.empty(num_experts,
intermediate_size //
self.quant_config.group_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, {
"is_transposed": True,
**extra_weight_attrs
})

# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = Parameter(torch.empty(
num_experts,
hidden_size // self.quant_config.group_size,
2 * intermediate_size // self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(
w13_qzeros, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
**extra_weight_attrs
})

w2_qzeros = Parameter(torch.empty(
num_experts,
intermediate_size // self.quant_config.group_size,
hidden_size // self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(
w2_qzeros, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
**extra_weight_attrs
})

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:

return fused_moe_awq(
x,
layer.w13_qweight,
layer.w2_qweight,
router_logits,
top_k,
renormalize=renormalize,
pack_factor=self.quant_config.pack_factor,
w1_scales=layer.w13_scales,
w2_scales=layer.w2_scales,
w1_qzeros=layer.w13_qzeros,
w2_qzeros=layer.w2_qzeros,
)
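As a sanity check on create_weights, these are the parameter shapes it registers for an illustrative Mixtral-like configuration (8 experts, hidden_size 4096, intermediate_size 14336, group_size 128, pack_factor 8 for int4-in-int32). The numbers are assumptions for illustration, not taken from the PR.

```python
# Shapes produced by AWQMoEMethod.create_weights under the assumed config.
num_experts, hidden, inter, group, pack = 8, 4096, 14336, 128, 8

expected_shapes = {
    "w13_qweight": (num_experts, hidden, 2 * inter // pack),           # (8, 4096, 3584)
    "w2_qweight": (num_experts, inter, hidden // pack),                # (8, 14336, 512)
    "w13_scales": (num_experts, hidden // group, 2 * inter),           # (8, 32, 28672)
    "w2_scales": (num_experts, inter // group, hidden),                # (8, 112, 4096)
    "w13_qzeros": (num_experts, hidden // group, 2 * inter // pack),   # (8, 32, 3584)
    "w2_qzeros": (num_experts, inter // group, hidden // pack),        # (8, 112, 512)
}
for name, shape in expected_shapes.items():
    print(name, shape)
```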