
Fixes and cleanup #82


Open
wants to merge 190 commits into base: modular-fused-experts
Commits

190 commits
1f86a2d
moe refactoring
bnellnm Apr 1, 2025
6b6630b
module deepgemm moe working
bnellnm Apr 1, 2025
33c61d6
working deep gemm, wip cutlass
bnellnm Apr 2, 2025
e1ab18a
working cutlass
bnellnm Apr 2, 2025
4965854
deepgemm working again
bnellnm Apr 2, 2025
a6564c1
cutlass working again
bnellnm Apr 2, 2025
1968c4a
cutlass working again
bnellnm Apr 2, 2025
be10ed4
fix inplace, format and name cleanups
bnellnm Apr 2, 2025
f8b64f5
fix inplace, format + name cleanups
bnellnm Apr 2, 2025
e94d6c1
test improvements
bnellnm Apr 3, 2025
d675138
make modular triton classes, fix edge cases
bnellnm Apr 3, 2025
a665564
fix outplace bug
bnellnm Apr 3, 2025
a6459aa
refactor dispatch/combine stuff
bnellnm Apr 3, 2025
da3fe2b
initial pplx dispatch/combine class
bnellnm Apr 3, 2025
f2fe65a
merge triton dispatch into standard, add some comments
bnellnm Apr 3, 2025
caf9805
format
bnellnm Apr 3, 2025
fab51b8
comments
bnellnm Apr 3, 2025
5694036
fix linter
bnellnm Apr 3, 2025
124f0ba
fix more linter stuff
bnellnm Apr 3, 2025
d98a3c3
cleanup for review
bnellnm Apr 3, 2025
a00c12c
review comments
bnellnm Apr 4, 2025
27cc6d1
forgot return
bnellnm Apr 4, 2025
1db9fcf
add dp_rank_num_tokens to DPMetadata
bnellnm Apr 4, 2025
374c55c
better check for fp8 in _fp8_permute
bnellnm Apr 4, 2025
fbf9370
updates
bnellnm Apr 28, 2025
be5c8d8
fix merge issues
bnellnm Apr 29, 2025
bdcaae2
fix lint
bnellnm Apr 29, 2025
e3ba64f
add pplx tests
bnellnm Apr 29, 2025
897f3a1
lint
bnellnm Apr 29, 2025
3feab72
undo random lint changes
bnellnm Apr 29, 2025
cf36977
more lint
bnellnm Apr 29, 2025
f861876
more lint nonsense
bnellnm Apr 29, 2025
f147062
WIP torch while
tlrmchlsmth Mar 15, 2025
eb58491
wip
tlrmchlsmth Mar 25, 2025
7a1b6e1
wip
tlrmchlsmth Mar 25, 2025
b1bef6f
wip
tlrmchlsmth Mar 27, 2025
fb73ea7
wip
tlrmchlsmth Mar 27, 2025
1cf5831
WIP integration
tlrmchlsmth Mar 28, 2025
91222dc
Add test for deep gemm matmul
bnellnm Feb 26, 2025
0e5081e
fix matmul test
bnellnm Feb 27, 2025
9e58fee
running
bnellnm Feb 27, 2025
a1ccb78
wip
bnellnm Feb 27, 2025
a1b033e
wip
bnellnm Feb 28, 2025
6b0aa02
debugging
bnellnm Feb 28, 2025
4b91cd4
debugging
bnellnm Feb 28, 2025
06acd02
fix
bnellnm Feb 28, 2025
24a90b9
update deep gemm
bnellnm Feb 28, 2025
49ec1c6
update deep gemm + small test case
bnellnm Mar 1, 2025
ceba476
wip
bnellnm Mar 2, 2025
9ac041c
wip
bnellnm Mar 2, 2025
696e6a2
problem with scores
bnellnm Mar 2, 2025
383c364
some passing tests
bnellnm Mar 3, 2025
4ff5a1a
some passing tests
bnellnm Mar 3, 2025
6c49e2d
topk > 1 doesn't work. prune oom-ing tests
bnellnm Mar 3, 2025
bada057
fix indices
bnellnm Mar 3, 2025
69788af
enable more tests
bnellnm Mar 3, 2025
cf83822
format
bnellnm Mar 3, 2025
3b57bf9
use fused_topk for unit test
bnellnm Mar 4, 2025
4b31217
every other block correct
bnellnm Mar 5, 2025
396cfa0
working
bnellnm Mar 5, 2025
25945cc
enable more tests
bnellnm Mar 5, 2025
4c3f491
working tests w/permute
bnellnm Mar 5, 2025
16e7d17
cleanups
bnellnm Mar 5, 2025
b05c810
wip
bnellnm Mar 6, 2025
0d15f37
not crashing
bnellnm Mar 6, 2025
9349503
baseline working integration
bnellnm Mar 6, 2025
3a744b3
add allow_deep_gemm flag
bnellnm Mar 6, 2025
449c6a1
wip
bnellnm Mar 7, 2025
5f21c96
better
bnellnm Mar 7, 2025
5bcbd93
fix some stuff
bnellnm Mar 8, 2025
2eeeedf
fix more stuff
bnellnm Mar 8, 2025
8fce359
cleanups
bnellnm Mar 8, 2025
2550173
some integration tests working
bnellnm Mar 8, 2025
1b19a9f
almost all tests passing
bnellnm Mar 10, 2025
e24d2c1
cleanup temp construction a bit
bnellnm Mar 10, 2025
c4a1a2c
fix rest of tests
bnellnm Mar 10, 2025
ce573fd
cleanups + format
bnellnm Mar 10, 2025
b402f57
do more of output computation in place
bnellnm Mar 10, 2025
4494c17
add env var
bnellnm Mar 10, 2025
afee1df
formatting, remove some blocking restrictions
bnellnm Mar 12, 2025
6872005
wip
bnellnm Mar 12, 2025
98b3256
fix resizing of output
bnellnm Mar 12, 2025
9060108
fix resizing of output
bnellnm Mar 12, 2025
411fc7a
fixes
bnellnm Mar 12, 2025
b773fdc
aligned chunking working for deep gemm
bnellnm Mar 12, 2025
ae2c791
unaligned chunking for deep gemm
bnellnm Mar 13, 2025
269f18b
cleanup wip
bnellnm Mar 13, 2025
0ea9a5d
clean up some blocking stuff
bnellnm Mar 13, 2025
52f53b3
clean up some blocking stuff
bnellnm Mar 13, 2025
de0135c
tweaks
bnellnm Mar 14, 2025
a4a0719
fix rebase
bnellnm Mar 15, 2025
58c733b
rebase
bnellnm Mar 17, 2025
3f397b0
refactoring + minor perf improvements
bnellnm Mar 21, 2025
48b55c4
refactoring + perf tweaks
bnellnm Mar 22, 2025
3c8704c
remove debugging cruft
bnellnm Mar 24, 2025
e885027
cache resize refactoring
bnellnm Mar 24, 2025
e8e6b6d
cleanups
bnellnm Mar 25, 2025
4c64246
format
bnellnm Mar 25, 2025
f3ff692
revert test.txt, fix mypy errors
bnellnm Mar 25, 2025
9d048ec
review comments
bnellnm Mar 26, 2025
9504dea
review comments
bnellnm Mar 27, 2025
ef0eee9
clean up use_dg flags
bnellnm Mar 27, 2025
e9c5c27
remove check for aligned M
bnellnm Mar 27, 2025
b3287ac
rebase + clean up test
bnellnm Mar 28, 2025
c820cfe
fix format
bnellnm Mar 28, 2025
30cfab4
Clean up diff
tlrmchlsmth Mar 31, 2025
24399c3
[Distributed] Add custom allreduce support for ROCM (#14125)
ilmarkov Apr 1, 2025
e58944e
[Bugfix][Model] fix mllama multi-image (#14883)
yma11 Apr 1, 2025
9fbd1a9
module deepgemm moe working
bnellnm Apr 1, 2025
a1eecd5
working deep gemm, wip cutlass
bnellnm Apr 2, 2025
6efacf1
working cutlass
bnellnm Apr 2, 2025
e1dd818
deepgemm working again
bnellnm Apr 2, 2025
bf02e1c
fix inplace, format and name cleanups
bnellnm Apr 2, 2025
0abc8e5
test improvements
bnellnm Apr 3, 2025
7896795
make modular triton classes, fix edge cases
bnellnm Apr 3, 2025
e42df0f
refactor dispatch/combine stuff
bnellnm Apr 3, 2025
f1cb920
initial pplx dispatch/combine class
bnellnm Apr 3, 2025
a023439
merge triton dispatch into standard, add some comments
bnellnm Apr 3, 2025
913c017
format
bnellnm Apr 3, 2025
8e3284a
cleanup for review
bnellnm Apr 3, 2025
cc1a878
hacking
bnellnm Apr 4, 2025
6c320cf
hacking
bnellnm Apr 7, 2025
ddc0b99
init stuff
bnellnm Apr 7, 2025
66ae985
call super ctor + fix random stuff
bnellnm Apr 7, 2025
4c23784
fix use_ep bug
tlrmchlsmth Apr 7, 2025
64e4281
Fix dp_size
tlrmchlsmth Apr 7, 2025
169403a
add comment
tlrmchlsmth Apr 7, 2025
c4110cb
fixes
tlrmchlsmth Apr 7, 2025
f4ae47f
get a bit further
bnellnm Apr 7, 2025
1d7aa87
hacking in dispatch_combine
bnellnm Apr 9, 2025
1316317
hook up some wires
bnellnm Apr 10, 2025
b466854
seems to be working
bnellnm Apr 10, 2025
f005ce6
wip
bnellnm Apr 11, 2025
b4dc0b1
batched moe test
bnellnm Apr 14, 2025
7490b67
simple test
bnellnm Apr 15, 2025
11edcc1
cleanup
bnellnm Apr 15, 2025
74f0a54
test pplx w/naive implementation
bnellnm Apr 15, 2025
406924d
test pplx w/naive implementation
bnellnm Apr 15, 2025
80164b9
hack fix for chunking loop
bnellnm Apr 15, 2025
84ea0bd
wip. add pplx unit test
bnellnm Apr 16, 2025
8a9895c
work on unit test
bnellnm Apr 17, 2025
d37a301
dispatch/combine unit test
bnellnm Apr 17, 2025
be94232
forgot file
bnellnm Apr 17, 2025
5e5b3ad
somewhat working unit test
bnellnm Apr 18, 2025
669e4f3
wip
bnellnm Apr 18, 2025
cbcf12a
fix test
bnellnm Apr 18, 2025
9a800ad
some cleanup
bnellnm Apr 19, 2025
7688233
wip
bnellnm Apr 19, 2025
02d42a7
wip
bnellnm Apr 29, 2025
27bee28
undo random changes
bnellnm Apr 29, 2025
089a71d
merge
bnellnm Apr 29, 2025
db3f01d
tweak
bnellnm Apr 29, 2025
17a978b
revert hack
bnellnm Apr 29, 2025
323d12f
fixes
bnellnm Apr 29, 2025
bb4e896
pplx update
bnellnm Apr 29, 2025
1260147
varun's fixes
bnellnm Apr 29, 2025
cb7bec9
varun's fixes
bnellnm Apr 29, 2025
a403f7c
tweak bound_m
bnellnm Apr 29, 2025
501d04c
run linter
bnellnm Apr 29, 2025
7f09738
more lint stuff
bnellnm Apr 29, 2025
d42186f
add guards for pplx import
bnellnm Apr 30, 2025
764e646
fix forward_chunked
Apr 30, 2025
2a31f90
fix more lint
bnellnm Apr 30, 2025
6a3daba
cleanups
bnellnm Apr 30, 2025
9590b96
cleanups + lint, layer.py wip
bnellnm Apr 30, 2025
ff40a9c
fix parallel_state lint
bnellnm Apr 30, 2025
c7cb7df
fix M=1 pplx test
bnellnm May 1, 2025
4f4584a
fix M=1 pplx test
bnellnm May 1, 2025
73226b9
fix M=1 pplx test
bnellnm May 1, 2025
a77fb2c
lint
bnellnm May 1, 2025
48ba146
remove valid pplx check
bnellnm May 1, 2025
4c40380
semi-working cudagraphs
bnellnm May 2, 2025
1938bc8
fix reference implementations
bnellnm May 2, 2025
0f2e37a
wip ref impl
bnellnm May 5, 2025
a003bd8
improve ref impl
bnellnm May 6, 2025
2bafbe0
wip
bnellnm May 6, 2025
ca763c3
fix merge
bnellnm May 6, 2025
054c10a
fix merge
bnellnm May 6, 2025
0851b31
wip
May 1, 2025
3e2cf4b
zero out attn outputs during profile run
May 7, 2025
d2862c0
lint
bnellnm May 7, 2025
e478c1a
lint
bnellnm May 7, 2025
90e9d05
revert lint changes to requirements/test.txt
bnellnm May 7, 2025
be1a8e5
revert lint changes to compiler_interface.py
bnellnm May 7, 2025
916f902
fix merge
bnellnm May 7, 2025
c04cb12
fix more lint errors
bnellnm May 7, 2025
f5bcc22
fix lint
bnellnm May 7, 2025
f7b7070
Fixes and cleanup
May 9, 2025
af02167
import sort
May 9, 2025
081f11f
fix shutdown
May 9, 2025
3 changes: 3 additions & 0 deletions csrc/activation_kernels.cu
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
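The zero-token guard above matters when a data-parallel or expert-parallel rank ends up with no tokens for a step: launching a CUDA kernel with an empty grid is invalid. A minimal Python sketch of the case it covers (assuming the existing silu_and_mul wrapper in vllm._custom_ops):

import torch
from vllm import _custom_ops as ops  # assumed existing wrapper for the activation op

d = 128
x = torch.empty((0, 2 * d), device="cuda", dtype=torch.float16)  # zero tokens
out = torch.empty((0, d), device="cuda", dtype=torch.float16)
ops.silu_and_mul(out, x)  # with the guard, this returns without launching a zero-sized grid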
14 changes: 14 additions & 0 deletions csrc/dispatch_utils.h
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
8 changes: 4 additions & 4 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}

if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");

VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
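The switch to VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES lets these kernels accept unsigned topk_ids tensors (as used by the pplx all-to-all path) in addition to the signed types supported before. A rough sketch of the intended call, assuming the moe_align_block_size helper keeps its (topk_ids, block_size, num_experts) signature and its location in vllm.model_executor.layers.fused_moe.fused_moe:

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size  # assumed location

num_tokens, topk, num_experts, block_size = 64, 8, 256, 16
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), device="cuda").to(torch.uint32)  # unsigned ids
sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(topk_ids, block_size, num_experts)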
63 changes: 45 additions & 18 deletions csrc/moe/topk_softmax_kernels.cu
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{

using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);

template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);

if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}
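Templating on IndType means topk_softmax can now write its indices into either an int32 or a uint32 tensor; the dispatch above asserts for anything else. A small sketch, assuming the topk_softmax wrapper in vllm._custom_ops keeps its (topk_weights, topk_ids, token_expert_indices, gating_output) argument order:

import torch
from vllm import _custom_ops as ops  # assumed existing wrapper around the _moe_C op

num_tokens, num_experts, topk = 32, 64, 6
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.uint32)  # previously int32 only
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
ops.topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)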
22 changes: 16 additions & 6 deletions examples/offline_inference/data_parallel.py
@@ -65,11 +65,17 @@ def parse_args():
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args()


def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
max_tokens=[16, 20][global_dp_rank % 2])

# Create an LLM.
llm = LLM(model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=True,
enable_expert_parallel=True)
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port,
tp_size))
tp_size, args.enforce_eager,
args.trust_remote_code))
proc.start()
procs.append(proc)
exit_code = 0
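With the two new arguments, eager mode and remote-code trust become opt-in from the command line rather than hard-coded, for example (assuming the script's existing --model/--dp-size/--tp-size arguments):

python examples/offline_inference/data_parallel.py --model <model> --dp-size 2 --tp-size 1 --enforce-eager --trust-remote-code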
107 changes: 107 additions & 0 deletions tests/kernels/moe/test_batched_moe.py
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 50.0
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype) / 50.0
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:

    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):

    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)

    test_output = tensors.C
    ref_output = test_output.clone()

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]
    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        test_output,
        tensors.num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)

    torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)
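The new test is self-contained, so it can be exercised directly (a CUDA device and Triton are assumed):

pytest -q tests/kernels/moe/test_batched_moe.py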