
Commit 1e88172

pcmoritz authored and WoosukKwon committed
[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (vllm-project#4343)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 192c704 commit 1e88172

File tree

  csrc/ops.h
  csrc/pybind.cpp
  csrc/quantization/fp8/fp8_cuda_kernels.cu
  vllm/_custom_ops.py
  vllm/model_executor/layers/fused_moe/fused_moe.py
  vllm/model_executor/layers/quantization/fp8.py
  vllm/model_executor/models/mixtral.py

7 files changed: +95 -18 lines

csrc/ops.h

Lines changed: 6 additions & 1 deletion
@@ -146,7 +146,12 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,
   torch::Tensor& input,
   torch::Tensor& scale);

csrc/pybind.cpp

Lines changed: 2 additions & 1 deletion
@@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-  ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
+  ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
+  ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
   ops.def(
     "moe_align_block_size",
     &moe_align_block_size,

csrc/quantization/fp8/fp8_cuda_kernels.cu

Lines changed: 24 additions & 1 deletion
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
 
 } // namespace vllm
 
-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., d]
+  torch::Tensor& scale)    // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  dim3 grid(num_tokens);
+  dim3 block(1024);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "scaled_fp8_quant_kernel",
+    [&] {
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+    });
+}
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input,    // [..., d]
   torch::Tensor& scale)    // [1]
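
Note: the static entry point above launches the existing scaled_fp8_quant_kernel with the caller-supplied scale, so no scale is computed on the device. As a rough reference for the intended semantics of the two ops, here is a plain-PyTorch sketch (not the CUDA kernel; the FP8-E4M3 maximum of 448.0, the clamping detail, and the max-based scale for the dynamic path are assumptions for illustration):

import torch

FP8_E4M3_MAX = 448.0  # assumed finite max of torch.float8_e4m3fn

def static_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static path: the caller supplies the scale, so only the element-wise
    # divide / clamp / cast remains.
    x_scaled = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_scaled.to(torch.float8_e4m3fn)

def dynamic_quant_ref(x: torch.Tensor):
    # Dynamic path: derive the scale from the tensor's absolute maximum first.
    scale = x.abs().max().float() / FP8_E4M3_MAX
    return static_quant_ref(x, scale), scale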

vllm/_custom_ops.py

Lines changed: 9 additions & 3 deletions
@@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
 
 
 # fp8
-def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    vllm_ops.scaled_fp8_quant(output, input, scale)
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        vllm_ops.static_scaled_fp8_quant(output, input, scale)
     return output, scale
 
 
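A minimal usage sketch of the updated Python wrapper, assuming a CUDA build of vLLM at this commit (the import alias and all values below are illustrative, not taken from the diff):

import torch

from vllm import _custom_ops as ops  # import path assumed

x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)

# Dynamic quantization: no scale is passed, so the op computes one from the input.
x_fp8, scale = ops.scaled_fp8_quant(x)

# Static quantization: reuse a precomputed activation scale (e.g. one loaded
# from a checkpoint); the 0.5 here is a made-up placeholder.
static_scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
x_fp8_static, _ = ops.scaled_fp8_quant(x, static_scale)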

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 9 additions & 4 deletions
@@ -220,8 +220,9 @@ def moe_align_block_size(
 
 
 def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            B_scale: torch.Tensor, topk_weights: torch.Tensor,
-                            topk_ids: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
@@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     assert sorted_token_ids.stride(0) == 1
 
     if not use_fp8:
-        A_scale = None
+        assert A_scale is None
         assert B_scale is None
     else:
-        A, A_scale = ops.scaled_fp8_quant(A)
+        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
 
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
@@ -318,6 +319,8 @@ def fused_moe(
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -434,6 +437,7 @@ def fused_moe(
     invoke_fused_moe_kernel(hidden_states,
                             w1,
                             intermediate_cache1,
+                            a1_scale,
                             w1_scale,
                             topk_weights,
                             topk_ids,
@@ -451,6 +455,7 @@ def fused_moe(
     invoke_fused_moe_kernel(intermediate_cache2,
                             w2,
                             intermediate_cache3,
+                            a2_scale,
                             w2_scale,
                             topk_weights,
                             topk_ids,
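
For callers, the net effect is two new optional keyword arguments on fused_moe. A hypothetical, heavily simplified call site is sketched below; the sizes, tensors, and unit scales are placeholders, the positional arguments (gating output, top-k, renormalize flag) follow the pre-existing signature rather than this diff, and actually running the FP8 path requires an FP8-capable GPU:

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, H, I = 4, 512, 1024  # experts, hidden size, intermediate size (illustrative)
x = torch.randn(8, H, device="cuda", dtype=torch.float16)
gating = torch.randn(8, E, device="cuda", dtype=torch.float16)

w1 = torch.randn(E, 2 * I, H, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
w2 = torch.randn(E, H, I, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
w1_scale = torch.ones(E, device="cuda", dtype=torch.float32)
w2_scale = torch.ones(E, device="cuda", dtype=torch.float32)
a1_scale = torch.ones(1, device="cuda", dtype=torch.float32)  # static scale for the first GEMM's input
a2_scale = torch.ones(1, device="cuda", dtype=torch.float32)  # static scale for the second GEMM's input

out = fused_moe(x, w1, w2, gating, 2, True,
                inplace=True,
                use_fp8=True,
                w1_scale=w1_scale, w2_scale=w2_scale,
                a1_scale=a1_scale, a2_scale=a2_scale)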

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 8 additions & 1 deletion
@@ -14,6 +14,12 @@
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+    ) -> None:
+        self.activation_scheme = activation_scheme
+
     @classmethod
     def get_name(cls) -> str:
         return "fp8"
@@ -35,7 +41,8 @@ def get_config_filenames(cls) -> List[str]:
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
-        return cls()
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        return cls(activation_scheme)
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
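
To illustrate the new plumbing, a short sketch of how a checkpoint's quantization config would select static activation scales (the dict literal is invented for illustration; from_config only reads the "activation_scheme" key shown in the diff):

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

fp8_cfg = Fp8Config.from_config({"quant_method": "fp8", "activation_scheme": "static"})
assert fp8_cfg.activation_scheme == "static"
# Setting the key to "dynamic" (or constructing Fp8Config() directly) keeps the
# previous behaviour of computing activation scales on the fly.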

vllm/model_executor/models/mixtral.py

Lines changed: 37 additions & 7 deletions
@@ -105,6 +105,13 @@ def __init__(
                         device="cuda",
                         dtype=self.params_dtype))
 
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
             torch.ones(
@@ -115,12 +122,23 @@ def __init__(
                 self.num_total_experts, device="cuda", dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
 
-        set_weight_attrs(self.ws, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-        })
+        # Scaling factors for FP8 activations
+        need_act_scales = (self.use_fp8
+                           and quant_config.activation_scheme == "static")
+        self.as_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+        self.a2s_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+
+        if need_act_scales:
+            set_weight_attrs(self.as_scale, {
+                "weight_loader": self.weight_loader,
+            })
+            set_weight_attrs(self.a2s_scale, {
+                "weight_loader": self.weight_loader,
+            })
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
@@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name:
+            param_data[:] = param_data[:].max(loaded_weight)
 
     def process_weights_after_loading(self):
         if self.use_fp8:
@@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                         inplace=True,
                                         use_fp8=self.use_fp8,
                                         w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale)
+                                        w2_scale=self.w2s_scale,
+                                        a1_scale=self.as_scale,
+                                        a2_scale=self.a2s_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         expert_params_mapping = [
+            # These are the weights for the experts
             # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
         ]
 
         params_dict = dict(self.named_parameters())
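
Note on the loader change above: when act_scale tensors are read from a checkpoint, each expert's scale is folded into the single shared parameter with an element-wise max, so the shared scale covers the largest (worst-case) expert. A standalone sketch of that reduction, with invented values:

import torch

# Mirrors `param_data[:] = param_data[:].max(loaded_weight)` from weight_loader.
shared = torch.zeros(1)
for expert_scale in (torch.tensor([0.7]), torch.tensor([1.3]), torch.tensor([0.9])):
    shared = shared.max(expert_scale)
print(shared)  # tensor([1.3000])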
