Commit 12358d3

fix up

1 parent 82ac2e4 commit 12358d3

6 files changed: +86 -108 lines

csrc/moe_alig_block_size_kernels.cu renamed to csrc/moe_align_block_size_kernels.cu

Lines changed: 8 additions & 8 deletions

@@ -11,7 +11,7 @@ const static size_t NUM_MAX_EXPERTS = 64;
 
 namespace vllm {
 template <typename scalar_t>
-__global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
+__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                            int32_t *sorted_token_ids,
                                            int32_t *expert_ids,
                                            int32_t *total_tokens_post_pad,
@@ -22,7 +22,7 @@ __global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
   const size_t start_idx = threadIdx.x * tokens_per_thread;
   __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
   __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
-  for(int i = 0;i < num_experts;i++){
+  for(int i = 0; i < num_experts; ++i){
    tokens_cnts[threadIdx.x + 1][i] = 0;
   }
 
@@ -33,23 +33,23 @@ __global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
   __syncthreads();
 
   tokens_cnts[0][threadIdx.x] = 0;
-  for(int i=1;i<=blockDim.x;++i){
+  for(int i = 1; i <= blockDim.x; ++i){
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
   }
 
   __syncthreads();
 
-  if(threadIdx.x ==0){
+  if(threadIdx.x == 0){
    cumsum[0] = 0;
-   for(int i=1;i<=num_experts;++i){
+   for(int i = 1; i <= num_experts; ++i){
     cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
   }
 
   __syncthreads();
 
-  for(int i= cumsum[threadIdx.x];i<cumsum[threadIdx.x + 1];i += block_size){
+  for(int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size){
    expert_ids[i / block_size] = threadIdx.x;
   }
 
@@ -62,7 +62,7 @@ __global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
   }
 }
 
-void moe_alig_block_size(
+void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
   int block_size,
@@ -73,7 +73,7 @@ void moe_alig_block_size(
   assert(num_experts <= NUM_MAX_EXPERTS);
   VLLM_DISPATCH_INTEGRAL_TYPES(
     topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
-      vllm::moe_alig_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
+      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
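
For orientation, the renamed kernel buckets every entry of topk_ids by its expert, rounds each expert's bucket up to a multiple of block_size, and reports the padded total. The snippet below is a rough CPU-side PyTorch sketch of that contract, for illustration only; it is not the CUDA implementation, and the choice of topk_ids.numel() as the fill value for padded slots is an assumption.

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                             num_experts: int):
    # Count how many tokens each expert was routed (cf. tokens_cnts above).
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)
    # Round each expert's count up to a multiple of block_size (cf. cumsum).
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.cat([torch.zeros(1, dtype=torch.long), padded.cumsum(0)])
    total_tokens_post_pad = int(cumsum[-1])
    # Padded slots keep an out-of-range index (assumed fill value).
    sorted_ids = torch.full((total_tokens_post_pad, ), flat.numel(),
                            dtype=torch.int32)
    expert_ids = torch.empty(total_tokens_post_pad // block_size,
                             dtype=torch.int32)
    for e in range(num_experts):
        start, end = int(cumsum[e]), int(cumsum[e + 1])
        # One expert id per block of block_size slots (cf. expert_ids above).
        expert_ids[start // block_size:end // block_size] = e
        token_idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_ids[start:start + token_idx.numel()] = token_idx.to(torch.int32)
    return sorted_ids, expert_ids, torch.tensor([total_tokens_post_pad])

For example, with num_experts=2, block_size=4 and topk_ids=[[0, 1], [1, 1]], expert 0 owns one token and expert 1 owns three, so both buckets are padded to four slots: num_tokens_post_pad is 8, expert_ids is [0, 1] (one entry per block), and sorted_ids lists the expert-0 position, padding, then the three expert-1 positions.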

csrc/ops.h

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
   const std::vector<std::vector<int64_t>> &offsets);
 #endif
 
-void moe_alig_block_size(
+void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
   int block_size,

csrc/pybind.cpp

Lines changed: 2 additions & 2 deletions

@@ -57,8 +57,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
-    "moe_alig_block_size",
-    &moe_alig_block_size,
+    "moe_align_block_size",
+    &moe_align_block_size,
     "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
 
   // Cache ops
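
Once the extension is rebuilt, the renamed binding is callable from Python with the same positional arguments that the wrapper in vllm/model_executor/layers/fused_moe.py passes. A minimal sketch follows; the vllm._C import path, the worst-case output sizes, and the fill value for sorted_ids are assumptions mirroring that wrapper.

import torch
from vllm._C import ops  # assumed import path for the compiled extension

topk_ids = torch.tensor([[0, 1], [1, 1]], dtype=torch.int32, device="cuda")
num_experts, block_size = 2, 4

# Worst case: every expert's bucket needs block_size - 1 slots of padding.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded, ), topk_ids.numel(),
                        dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_num_tokens_padded + block_size - 1) // block_size,
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

# Argument order as in the Python wrapper shown in this commit.
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                         expert_ids, num_tokens_post_pad)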

setup.py

Lines changed: 1 addition & 1 deletion

@@ -305,7 +305,7 @@ def get_torch_arch_list() -> Set[str]:
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
-    "csrc/moe_alig_block_size_kernels.cu",
+    "csrc/moe_align_block_size_kernels.cu",
     "csrc/pybind.cpp",
 ]
 

vllm/model_executor/layers/fused_moe.py

Lines changed: 58 additions & 74 deletions

@@ -16,7 +16,6 @@ def fused_moe_kernel(
         expert_ids_ptr,
         num_tokens_post_padded_ptr,
         # Matrix dimensions
-        M,
         N,
         K,
         EM,
@@ -86,10 +85,9 @@ fused_moe_kernel(
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)
 
-    #
-    off_experts = tl.load(expert_ids_ptr + pid_m) * stride_be
-    b_ptrs = b_ptr + off_experts + (offs_k[:, None] * stride_bk +
-                                    offs_bn[None, :] * stride_bn)
+    off_experts = tl.load(expert_ids_ptr + pid_m)
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -129,7 +127,7 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def alig_block_size(
+def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
     """
@@ -169,11 +167,48 @@ def alig_block_size(
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
-    ops.moe_alig_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                            expert_ids, num_tokens_post_pad)
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
+def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            sorted_token_ids: torch.Tensor,
+                            expert_ids: torch.Tensor,
+                            num_tokens_post_padded: torch.Tensor,
+                            mul_routed_weight: bool, top_k: int, config: dict):
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
+        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+
+    fused_moe_kernel[grid](
+        A,
+        B,
+        C,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.shape[1],
+        B.shape[2],
+        sorted_token_ids.shape[0],
+        topk_ids.numel(),
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        topk_weights.stride(1),
+        sorted_token_ids.stride(0),
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
+        **config,
+    )
+
+
 def fused_moe(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w2: torch.Tensor,
@@ -196,11 +231,12 @@ def fused_moe(hidden_states: torch.Tensor,
     """
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
-    assert hidden_states.is_contiguous(), "Matrix A must be contiguous"
-    assert w1.is_contiguous(), "Matrix B must be contiguous"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float16, torch.bfloat16]
-    M, K = hidden_states.shape
-    E, N, K = w1.shape
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
 
     config = {
         'BLOCK_SIZE_M': 64,
@@ -227,73 +263,21 @@ def fused_moe(hidden_states: torch.Tensor,
                                        device=hidden_states.device,
                                        dtype=hidden_states.dtype)
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = alig_block_size(
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
-    # 1D launch kernel where each block gets its own program.
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
 
-    fused_moe_kernel[grid](
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        M,
-        N,
-        K,
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        hidden_states.stride(0),
-        hidden_states.stride(1),
-        w1.stride(0),
-        w1.stride(2),
-        w1.stride(1),
-        intermediate_cache1.stride(1),
-        intermediate_cache1.stride(2),
-        topk_weights.stride(1),
-        sorted_token_ids.stride(0),
-        MUL_ROUTED_WEIGHT=False,
-        top_k=topk_ids.shape[1],
-        compute_type=tl.bfloat16
-        if hidden_states.dtype == torch.bfloat16 else tl.float16,
-        **config,
-    )
+    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
+                            topk_weights, topk_ids, sorted_token_ids,
+                            expert_ids, num_tokens_post_padded, False,
+                            topk_ids.shape[1], config)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(w2.shape[1], META['BLOCK_SIZE_N']), )
-    fused_moe_kernel[grid](
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        M,
-        w2.shape[1],
-        w2.shape[2],
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        intermediate_cache2.stride(0),
-        intermediate_cache2.stride(1),
-        w2.stride(0),
-        w2.stride(2),
-        w2.stride(1),
-        intermediate_cache3.stride(1),
-        intermediate_cache3.stride(2),
-        topk_weights.stride(1),
-        sorted_token_ids.stride(0),
-        MUL_ROUTED_WEIGHT=True,
-        top_k=1, #
-        compute_type=tl.bfloat16
-        if hidden_states.dtype == torch.bfloat16 else tl.float16,
-        **config,
-    )
+    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
+                            topk_weights, topk_ids, sorted_token_ids,
+                            expert_ids, num_tokens_post_padded, True, 1,
+                            config)
+
     if inplace:
         return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                          dim=1,
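
The refactor keeps the public fused_moe signature unchanged, so callers such as DeepseekMoE below still hand it routing weights and expert indices directly. The following is a hypothetical smoke test, not part of the change; the toy sizes and the expert weight layouts (w1 as [E, 2*inner, hidden], w2 as [E, hidden, inner]) are assumptions read off the assertions and strides in the diff, and it needs a CUDA device with Triton installed.

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

M, hidden, inner, E, top_k = 16, 128, 64, 8, 2  # toy sizes (assumed)
hidden_states = torch.randn(M, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * inner, hidden, dtype=torch.float16, device="cuda")  # fused gate/up proj (assumed layout)
w2 = torch.randn(E, hidden, inner, dtype=torch.float16, device="cuda")      # down proj (assumed layout)

# Toy router: softmax over experts, then top-k selection per token.
gating = torch.randn(M, E, device="cuda").softmax(dim=-1)
topk_weights, topk_ids = gating.topk(top_k, dim=-1)

out = fused_moe(hidden_states, w1, w2, topk_weights.half(), topk_ids,
                inplace=True)
print(out.shape)  # should match hidden_states: torch.Size([16, 128])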

vllm/model_executor/models/deepseek.py

Lines changed: 16 additions & 22 deletions

@@ -149,16 +149,6 @@ def pack_params(self):
 
         self.w2 = self.w2.view(len(w2), *w2s[0].shape)
 
-    def fused_moe_infer(self, hidden_states: torch.Tensor,
-                        selected_experts: torch.Tensor,
-                        routing_weights: torch.Tensor) -> torch.Tensor:
-        return fused_moe(hidden_states,
-                         self.w1,
-                         self.w2,
-                         routing_weights,
-                         selected_experts,
-                         inplace=True)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -175,9 +165,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.config.norm_topk_prob:
             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
-        final_hidden_states = self.fused_moe_infer(hidden_states,
-                                                   selected_experts,
-                                                   routing_weights)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.w1,
+                                        self.w2,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)
 
         if self.config.n_shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -290,15 +283,16 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
         )
-        self.mlp = DeepseekMoE(config=config,
-                               linear_method=linear_method) if (config.n_routed_experts is not None and \
-            layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0) \
-            else DeepseekMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                linear_method=linear_method,
-            )
+        if (config.n_routed_experts is not None and \
+                layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0):
+            self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
+        else:
+            self.mlp = DeepseekMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                linear_method=linear_method,
+            )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
