[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales #4343

Merged
7 changes: 6 additions & 1 deletion csrc/ops.h
@@ -146,7 +146,12 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,
   torch::Tensor& input,
   torch::Tensor& scale);
3 changes: 2 additions & 1 deletion csrc/pybind.cpp
@@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
25 changes: 24 additions & 1 deletion csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(

} // namespace vllm

-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., d]
+  torch::Tensor& scale)    // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  dim3 grid(num_tokens);
+  dim3 block(1024);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "scaled_fp8_quant_kernel",
+    [&] {
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+    });
+}
+
+void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
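For readers skimming the kernel change: the static entry point shown above launches `scaled_fp8_quant_kernel` with the caller's scale, while the dynamic one (truncated here) presumably computes the scale from the input first. A minimal PyTorch sketch of the intended semantics, under my reading of the diff -- an illustration, not the kernel code; `FP8_E4M3_MAX` is an assumed helper constant:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def static_quant_reference(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static path: the caller supplies a precomputed scale.
    q = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn)

def dynamic_quant_reference(x: torch.Tensor):
    # Dynamic path: derive the scale from the tensor itself, then quantize.
    scale = x.abs().max().to(torch.float32) / FP8_E4M3_MAX
    q = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn), scale
```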
12 changes: 9 additions & 3 deletions vllm/_custom_ops.py
@@ -154,10 +154,16 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,


# fp8
-def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    vllm_ops.scaled_fp8_quant(output, input, scale)
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        vllm_ops.static_scaled_fp8_quant(output, input, scale)
     return output, scale


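A quick usage sketch for the updated wrapper (hypothetical tensors; this assumes `vllm._custom_ops` is importable as `ops`, as in the MoE kernel below):

```python
import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic path: no scale supplied, so one is computed and returned.
x_fp8, scale = ops.scaled_fp8_quant(x)

# Static path: reuse a precomputed (e.g. calibrated) scale.
x_fp8_again, same_scale = ops.scaled_fp8_quant(x, scale)
assert same_scale is scale  # the wrapper returns the scale it was given
```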
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -220,8 +220,9 @@ def moe_align_block_size(


 def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            B_scale: torch.Tensor, topk_weights: torch.Tensor,
-                            topk_ids: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert sorted_token_ids.stride(0) == 1

     if not use_fp8:
-        A_scale = None
+        assert A_scale is None
         assert B_scale is None
     else:
-        A, A_scale = ops.scaled_fp8_quant(A)
+        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
@@ -318,6 +319,8 @@ def fused_moe(
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -434,6 +437,7 @@
         invoke_fused_moe_kernel(hidden_states,
                                 w1,
                                 intermediate_cache1,
+                                a1_scale,
                                 w1_scale,
                                 topk_weights,
                                 topk_ids,
@@ -451,6 +455,7 @@
         invoke_fused_moe_kernel(intermediate_cache2,
                                 w2,
                                 intermediate_cache3,
+                                a2_scale,
                                 w2_scale,
                                 topk_weights,
                                 topk_ids,
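With the new parameters threaded through, a hedged end-to-end sketch of calling `fused_moe` with FP8 weights plus static activation scales. Shapes, dtypes, and the ones-initialized scales are illustrative assumptions, not values from the PR:

```python
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

E, d, inter = 8, 4096, 14336  # experts, hidden size, intermediate size
hidden_states = torch.randn(4, d, dtype=torch.float16, device="cuda")
gating_output = torch.randn(4, E, dtype=torch.float32, device="cuda")

# FP8 expert weights (w1 fuses the gate/up projections, hence 2 * inter).
w1 = torch.empty(E, 2 * inter, d, dtype=torch.float8_e4m3fn, device="cuda")
w2 = torch.empty(E, d, inter, dtype=torch.float8_e4m3fn, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output,
                topk=2, renormalize=True,
                use_fp8=True,
                w1_scale=torch.ones(E, device="cuda"),  # per-expert weight scales
                w2_scale=torch.ones(E, device="cuda"),
                a1_scale=torch.ones(1, device="cuda"),  # static activation scales
                a2_scale=torch.ones(1, device="cuda"))
```

If `a1_scale`/`a2_scale` are left as `None`, the kernel falls back to the dynamic per-call quantization path shown in `invoke_fused_moe_kernel` above.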
9 changes: 8 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -13,6 +13,12 @@
class FP8Config(QuantizationConfig):
"""Config class for FP8."""

+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+    ) -> None:
+        self.activation_scheme = activation_scheme
+
@classmethod
def get_name(cls) -> str:
return "fp8"
@@ -34,7 +40,8 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
-        return cls()
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        return cls(activation_scheme)

def get_linear_method(self) -> "Fp8LinearMethod":
return Fp8LinearMethod(self)
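Based on the `from_config` change, a sketch of how a checkpoint's quantization config would now be parsed (the dict contents are a hypothetical example):

```python
from vllm.model_executor.layers.quantization.fp8 import FP8Config

# Hypothetical quantization section of a checkpoint config.
quant_config = {"activation_scheme": "static"}

fp8_config = FP8Config.from_config(quant_config)
assert fp8_config.activation_scheme == "static"
```

One design note, if I read `get_from_keys` correctly: it raises when the key is absent, so after this change a config without an `activation_scheme` entry would fail to load rather than silently defaulting to "dynamic".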
45 changes: 38 additions & 7 deletions vllm/model_executor/models/mixtral.py
@@ -104,6 +104,13 @@ def __init__(
device="cuda",
dtype=self.params_dtype))

+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
@@ -114,12 +121,24 @@
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None

-        set_weight_attrs(self.ws, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-        })
+        # Scaling factors for FP8 activations
+        need_act_scales = (self.use_fp8
+                           and linear_method.quant_config.activation_scheme
+                           == "static")
+        self.as_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+        self.a2s_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+
+        if need_act_scales:
+            set_weight_attrs(self.as_scale, {
+                "weight_loader": self.weight_loader,
+            })
+            set_weight_attrs(self.a2s_scale, {
+                "weight_loader": self.weight_loader,
+            })

Collaborator (on the `device="cuda"` lines):
nit: Theoretically, we shouldn't use "cuda" in the model code. Since the GPU worker sets "cuda" as the default device in torch, device="cuda" is not necessary. Also, it's not good for the compatibility with non-CUDA devices.
This rule is violated for Mixtral and other MoE models unfortunately. 😢

Collaborator Author:
Ok, I'll make a follow up PR to remove the device="cuda" -- since we also specify it explicitly for the other parameters, I don't want to be inconsistent for this PR :)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):

Collaborator:
Unrelated to this PR, I think we should have an MoELayer that is shared across models.
All of these changes are currently only impacting Mixtral, but could also be applied to other models. Since we have all this generic logic in the model definitions, we are losing out at running others with these features.
@@ -134,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name:
+            param_data[:] = param_data[:].max(loaded_weight)

def process_weights_after_loading(self):
if self.use_fp8:
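
Worth spelling out: w1 and w3 share `as_scale` (they feed the same fused first GEMM), so the `act_scale` branch in `weight_loader` above folds every loaded scale into the parameter with a running elementwise max. A standalone sketch of that merge, with made-up numbers:

```python
import torch

as_scale = torch.zeros(1)  # same initialization as self.as_scale above

# Hypothetical calibrated scales from the checkpoint.
loaded_scales = [
    torch.tensor([0.02]),  # experts.0.w1.act_scale
    torch.tensor([0.05]),  # experts.0.w3.act_scale
    torch.tensor([0.03]),  # experts.1.w1.act_scale
]
for s in loaded_scales:
    as_scale[:] = as_scale[:].max(s)  # keep the largest scale seen so far

print(as_scale)  # tensor([0.0500])
```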
@@ -161,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             inplace=True,
             use_fp8=self.use_fp8,
             w1_scale=self.ws_scale,
-            w2_scale=self.w2s_scale)
+            w2_scale=self.w2s_scale,
+            a1_scale=self.as_scale,
+            a2_scale=self.a2s_scale)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@@ -443,11 +466,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]

         expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
         ]
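
For concreteness, the first entries this mapping expands to look like the following (derived mechanically from the comprehensions above):

```python
# (param_name, weight_name, expert_id)
[("ws",        "experts.0.w1.weight",    0),
 ("w2s",       "experts.0.w2.weight",    0),
 ("ws",        "experts.0.w3.weight",    0),
 # ... remaining experts ...
 ("as_scale",  "experts.0.w1.act_scale", 0),
 ("a2s_scale", "experts.0.w2.act_scale", 0),
 ("as_scale",  "experts.0.w3.act_scale", 0)]
 # ... remaining experts ...
```

`load_weights` can then match each checkpoint tensor name against `weight_name` and route it to the right parameter and expert slot via `weight_loader`.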

params_dict = dict(self.named_parameters())