
Commit c5a9406

gshtras and charlifu authored
Deepseek V3 support (#364)
* Changing the hard coded datatype to see if it's enough for the model to work
* Picking the upstream moe kernel version
* Make the upstream fix for v3 also work for ROCm v2
* Conditional fnuz dtype
* Requantizing from fn to fnuz
* Requantizing moe as well
* Actually requantizing moe weights
* Conditional requantization and assert on padding in block quant
* Format

---------

Co-authored-by: charlifu <charlifu@amd.com>
1 parent 8bd76fb commit c5a9406

File tree

3 files changed: +119 -36 lines changed


csrc/moe/moe_align_sum_kernels.cu

Lines changed: 71 additions & 33 deletions
@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }
 
@@ -83,9 +83,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }
 
@@ -140,11 +141,11 @@ __global__ void moe_align_block_size_global_mem_kernel(
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }
 
@@ -168,9 +169,10 @@ __global__ void moe_align_block_size_global_mem_kernel(
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }
 
@@ -221,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
-        const int32_t num_thread = WARP_SIZE;
-        const int32_t shared_mem =
-            ((num_thread + 1) * num_experts + (num_experts + 1)) *
-            sizeof(int32_t);
-
-        // set dynamic shared mem
-        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-            (void*)kernel, shared_mem));
-        kernel<<<1, num_thread, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-            experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
-      });
+
+  // If we have very large number of experts, we can no longer use shared
+  // memory.
+  // TODO(simon): the right solution should be calculating the exact right
+  // amount of shared memory and use that. The num_experts >= 256 is just a
+  // temporary solution to unblock Deepseek V3.
+  if (num_experts >= 96) {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+          const int32_t mem_tokens_cnts =
+              ((num_experts + 1) * num_experts) * sizeof(int32_t);
+          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
+          // allocate global memory
+          int32_t* tokens_cnts;
+          int32_t* cumsum;
+          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+          cudaMalloc(&cumsum, mem_cumsum);
+
+          auto kernel =
+              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+          kernel<<<1, num_thread, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), tokens_cnts, cumsum);
+          cudaFree(tokens_cnts);
+          cudaFree(cumsum);
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+          const int32_t shared_mem =
+              ((num_thread + 1) * num_experts + (num_experts + 1)) *
+              sizeof(int32_t);
+
+          // set dynamic shared mem
+          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem));
+          kernel<<<1, num_thread, shared_mem, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  }
 }
 
 void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]
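Why the new global-memory fallback: the shared-memory kernel needs roughly (num_thread + 1) * num_experts + (num_experts + 1) int32 slots, which grows quadratically once num_thread is tied to num_experts. A back-of-the-envelope check in Python, mirroring the launcher's formula above (the 64 KB per-block budget is an assumed typical figure; the real limit depends on the GPU and on what VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize can raise it to):

def moe_align_shared_mem_bytes(num_experts: int, warp_size: int = 64) -> int:
    # (num_thread + 1) rows of per-thread token counts plus the cumsum array,
    # all stored as int32 (4 bytes each).
    num_thread = max(num_experts, warp_size)
    return ((num_thread + 1) * num_experts + (num_experts + 1)) * 4


print(moe_align_shared_mem_bytes(64))   # 16900 bytes -- fits comfortably
print(moe_align_shared_mem_bytes(256))  # 264196 bytes -- far past a ~64 KB budget

With DeepSeek V3's 256 routed experts the requirement blows well past typical shared-memory limits, hence the cudaMalloc'd tokens_cnts/cumsum path for num_experts >= 96.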

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 37 additions & 1 deletion
@@ -166,6 +166,8 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")
 
         if self.block_quant:
+            assert not envs.VLLM_FP8_PADDING, (
+                "FP8 weight padding is not supported in block quantization.")
             tp_size = get_tensor_model_parallel_world_size()
             assert self.quant_config.weight_block_size is not None
             block_n, block_k = (
@@ -196,8 +198,9 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
 
+        fp8_dtype = torch.float8_e4m3fn
         # WEIGHT
-        weight_dtype = (torch.float8_e4m3fn
+        weight_dtype = (fp8_dtype
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
 
@@ -252,6 +255,15 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm() and not is_navi():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
@@ -512,6 +524,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm() and not is_navi():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 11 additions & 2 deletions
@@ -5,6 +5,9 @@
 import triton
 import triton.language as tl
 
+from vllm.platforms import current_platform
+from vllm.utils import is_navi
+
 
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
@@ -34,10 +37,13 @@ def apply_w8a8_block_fp8_linear(
 
 def input_to_float8(
     x: torch.Tensor,
-    dtype: torch.dtype = torch.float8_e4m3fn
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+                 and not is_navi() else torch.float8_e4m3fn)
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -125,7 +131,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -140,6 +146,9 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+                 and not is_navi() else torch.float8_e4m3fn)
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")
