permute/unpermute kernel for moe optimization #14568

Merged: 13 commits into vllm-project:main on May 2, 2025

Conversation

@CalebDu (Contributor) commented Mar 10, 2025

The moe_permute kernel expands and orders tokens in the activation so that each expert's non-contiguous tokens are gathered together; a grouped GEMM can then be called for MoE speedup.
The moe_unpermute kernel reduces the expanded grouped-GEMM output and scales it with topk_weight.

This implementation is based on the MoE kernels in TensorRT-LLM, archived at https://github.com/BBuf/tensorrt-llm-moe/tree/master.
Expert parallelism with expert_map is currently unsupported; support will follow in a later update.
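For orientation, here is a rough end-to-end sketch of how the two ops are meant to be used. Argument names and the return values of moe_permute follow the test code later in this thread; `grouped_gemm`, `expert_weights`, `inv_perm`, and the moe_unpermute signature are illustrative placeholders, not the final API.

```python
# Hypothetical flow: topk routing -> permute -> grouped GEMM -> unpermute.
topk_weights, topk_ids, token_expert_indices = fused_topk(
    hidden_states, gating_output, topk, False)

# moe_permute gathers each expert's tokens into a contiguous block and returns
# the permuted activation, the per-expert first-token offsets, and an index
# map that moe_unpermute uses to restore the original token order.
permuted_hidden, expert_first_token_offset, inv_perm = moe_permute(
    hidden_states, topk_weights, topk_ids, token_expert_indices,
    topk, n_expert, n_local_expert, expert_map)

# Grouped GEMM over the contiguous per-expert blocks (placeholder call).
grouped_out = grouped_gemm(permuted_hidden, expert_weights,
                           expert_first_token_offset)

# moe_unpermute reduces the topk copies of each token, scaled by topk_weights,
# back into the original token order (signature is illustrative only).
out = moe_unpermute(grouped_out, topk_weights, topk_ids, inv_perm,
                    expert_first_token_offset)
```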


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Mar 10, 2025
@gzy19990617

Is this kernel intended as preparation for the upcoming group_gemm work?

@CalebDu (Contributor Author) commented Mar 12, 2025

Is this kernel intended as preparation for the upcoming group_gemm work?

Yes.

@gzy19990617 commented Mar 12, 2025 via email

When is this expected to be supported?

@CalebDu (Contributor Author) commented Mar 12, 2025

When is this expected to be supported?

I am not sure of the exact release timeline either.

@CalebDu CalebDu force-pushed the caleb_dev branch 3 times, most recently from 8d6c8fb to 076bee0 Compare March 12, 2025 16:22
@CalebDu (Contributor Author) commented Mar 13, 2025

Added expert_map support for expert parallelism.

@CalebDu (Contributor Author) commented Mar 15, 2025

This PR adds two kernels: moe_permute and moe_unpermute.
moe_permute includes:

  • preprocess_topk_id_launcher (optional, for EP): preprocess_topk_id maps global expert ids to local expert ids for each EP rank. For example, with 4 experts and ep_size=2, ep_rank=1 owns global expert ids [2, 3] and has expert_map=[-1, -1, 0, 1]. preprocess_topk_id processes topk_ids, mapping global expert ids [2, 3] to local expert ids [0, 1], and mapping global expert ids [0, 1] (not on ep_rank=1) to [4, 5] by adding n_expert. This mapping gives local experts higher priority in the subsequent sort of topk_ids and the scan of expert_first_token_offset, so sorting topk_ids yields results in global expert id order [2, 3, 0, 1] (local expert ids [0, 1, 4, 5]), and the scan produces expert_first_token_offset for local expert ids 0 and 1 (global expert ids 2 and 3). A torch sketch of this remapping follows the list.
  • sortAndScanExpert: calls cub's key-value radix sort to sort the local expert ids in topk_ids and obtain the mapping indices for expanding and permuting the input activation, then calls computeExpertFirstTokenOffset to get each expert's first-token offset along the m-dim of the permuted activation for the following grouped GEMM.
  • expandInputRowsKernelLauncher: expands and permutes the input activation using the index map from sortAndScanExpert, and produces the mapping indices used by moe_unpermute.
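A minimal torch sketch of the preprocess_topk_id remapping, reconstructed from the example above (the actual kernel is a CUDA launcher, so this is illustrative only):

```python
import torch

def preprocess_topk_ids_sketch(topk_ids: torch.Tensor,
                               expert_map: torch.Tensor,
                               n_expert: int) -> torch.Tensor:
    """Map global expert ids to local ids; push non-local experts past n_expert.

    Example from above: n_expert=4, ep_size=2, ep_rank=1 owns global experts
    [2, 3], expert_map=[-1, -1, 0, 1]. Global ids [2, 3] -> local ids [0, 1];
    global ids [0, 1] -> [4, 5], so a radix sort ranks local experts first.
    """
    local_ids = expert_map[topk_ids]          # -1 for experts not on this rank
    non_local = local_ids < 0
    local_ids[non_local] = topk_ids[non_local] + n_expert
    return local_ids

# preprocess_topk_ids_sketch(torch.tensor([2, 3, 0, 1]),
#                            torch.tensor([-1, -1, 0, 1]), 4)
# -> tensor([0, 1, 4, 5])
```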

moe_unpermute includes:

  • finalizeMoeRoutingKernelLauncher: reduces the permuted activation, scaling by topk_weight, and restores the original token order. A plain-torch sketch of the intended semantics follows this list.
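A plain-torch sketch of that reduction (semantics only, assuming `inv_perm` is the index map produced by moe_permute; the real kernel fuses the gather, scale, and reduce):

```python
import torch

def moe_unpermute_sketch(permuted_out: torch.Tensor,   # [n_token * topk, n_hidden]
                         topk_weights: torch.Tensor,   # [n_token, topk]
                         inv_perm: torch.Tensor,       # [n_token * topk] gather index
                         n_token: int, topk: int) -> torch.Tensor:
    # Gather rows back into (token, k) order, then weight and sum over topk.
    restored = permuted_out[inv_perm].view(n_token, topk, -1)
    return (restored * topk_weights.unsqueeze(-1)).sum(dim=1)
```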

@bnellnm (Contributor) commented Mar 18, 2025

Hi @CalebDu , thanks for working on this. I think this might help me out with #13932. I am currently trying to use your PR to do the permute/unpermute steps needed for the DeepGemm grouped gemm kernel.

I did run into a test failure with certain problem sizes that are not in the original test, e.g.

@pytest.mark.parametrize("n_token", [2048])
@pytest.mark.parametrize("n_hidden", [7168])
@pytest.mark.parametrize("n_expert", [64])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", [1])
def test_moe_permute_unpermute(
    n_token: int, n_hidden: int, topk: int, n_expert: int, ep_size: int, dtype: torch.dtype
)

========================================================= FAILURES ==========================================================
____________________________________ test_moe_permute_unpermute[1-dtype0-6-64-7168-2048] ____________________________________

n_token = 2048, n_hidden = 7168, topk = 6, n_expert = 64, ep_size = 1, dtype = torch.bfloat16

    @pytest.mark.parametrize("n_token", [2048])
    @pytest.mark.parametrize("n_hidden", [7168])
    @pytest.mark.parametrize("n_expert", [64])
    @pytest.mark.parametrize("topk", [6])
    #@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("dtype", [torch.bfloat16])
    #@pytest.mark.parametrize("ep_size", EP_SIZE)
    @pytest.mark.parametrize("ep_size", [1])
    def test_moe_permute_unpermute(
        n_token: int, n_hidden: int, topk: int, n_expert: int, ep_size: int, dtype: torch.dtype
    ):
        ep_rank= 0 #np.random.randint(0, ep_size)
        expert_map = None
        n_local_expert = n_expert
        if(ep_size != 1) :
            n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
            expert_map = expert_map.cuda()
        start_expert = n_local_expert * ep_rank
        current_platform.seed_everything(0)
        hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
        gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states, gating_output, topk, False
        )

        gold0, gold1, gold2 = torch_permute(hidden_states,
                                            topk_ids,
                                            token_expert_indices,
                                            topk, n_expert,n_local_expert,
                                            start_expert, expert_map=expert_map)

        result0, result1, result2 = moe_permute(hidden_states,
                                                topk_weights, topk_ids,
                                                token_expert_indices,
                                                topk, n_expert, n_local_expert, expert_map
                                                )
        # print(gold0, result0)
        #print(gold1, result1)
        # print(gold2, result2)

        # check expert_first_token_offset
>       torch.testing.assert_close(gold1,
                                   result1,
                                   atol=0,
                                   rtol=0)
E       AssertionError: Tensor-likes are not equal!
E       
E       Mismatched elements: 2 / 65 (3.1%)
E       Greatest absolute difference: 106 at index (3,)
E       Greatest relative difference: 0.150141641497612 at index (3,)

tests/kernels/test_moe_permute_unpermute.py:129: AssertionError
================================================== short test summary info ==================================================
FAILED tests/kernels/test_moe_permute_unpermute.py::test_moe_permute_unpermute[1-dtype0-6-64-7168-2048] - AssertionError: Tensor-likes are not equal!
===================================================== 1 failed in 3.08s =====================================================

Do you have any idea where the problem might be?

@CalebDu (Contributor Author) commented Mar 19, 2025

Hi @CalebDu , thanks for working on this. I think this might help me out with #13932. [...] Do you have any idea where the problem might be?

I'll try this test case, figure out why the results mismatch, and fix it soon.

@CalebDu (Contributor Author) commented Mar 19, 2025

@bnellnm the bug was caused by the workspace being too small for the cub radix sort. After expanding the workspace, the mismatch is fixed.

@bnellnm (Contributor) commented Mar 19, 2025

@bnellnm the bug was caused by the workspace being too small for the cub radix sort. After expanding the workspace, the mismatch is fixed.

Thanks!

mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @CalebDu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@CalebDu (Contributor Author) commented Mar 26, 2025

Added align_block_size support for the contiguous grouped GEMM in DeepGEMM.
The token count of each expert is rounded up to align_block_size, and the aligned counts are scanned to get align_first_token_offset for each expert. In the permuted hidden states, align_permuted_token_idx = permuted_token_idx - first_token_offset + align_first_token_offset, which makes permuted_hidden fit DeepGEMM's contiguous grouped GEMM with align_block_size.
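A small numeric sketch of that alignment and the resulting offsets (assuming align_block_size=128 and two local experts with 127 and 129 tokens; illustrative only):

```python
import torch

def aligned_offsets_sketch(tokens_per_expert: torch.Tensor, align_block_size: int):
    """Round each expert's token count up to align_block_size, then prefix-scan."""
    zero = torch.zeros(1, dtype=torch.long)
    first_token_offset = torch.cumsum(torch.cat([zero, tokens_per_expert]), dim=0)
    padded = ((tokens_per_expert + align_block_size - 1)
              // align_block_size) * align_block_size
    align_first_token_offset = torch.cumsum(torch.cat([zero, padded]), dim=0)
    return first_token_offset, align_first_token_offset

# aligned_offsets_sketch(torch.tensor([127, 129]), 128)
# -> (tensor([  0, 127, 256]), tensor([  0, 128, 384]))
# A permuted row for expert e then lands at:
#   align_permuted_token_idx = permuted_token_idx - first_token_offset[e] + align_first_token_offset[e]
```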

@bnellnm (Contributor) commented Mar 27, 2025

@CalebDu , thanks for adding the blocking support! I've been working on integrating the new version with DeepGemm but I'm running into problems with the m_indices. I don't think they are being computed correctly. I think the best way to test for correctness is to use moe_align_block_size like this:

block_m = 128
_, expert_ids, _ = moe_align_block_size(curr_topk_ids, block_m, global_num_experts, expert_map)
m_indices = torch.repeat_interleave(expert_ids, block_m, dim=0)

@CalebDu (Contributor Author) commented Mar 27, 2025

@CalebDu , thanks for adding the blocking support! I've been working on integrating the new version with DeepGemm but I'm running into problems with the m_indices. I don't think they are being computed correctly. I think the best way to test for correctness is to use moe_align_block_size like this:

block_m = 128
_, expert_ids, _ = moe_align_block_size(curr_topk_ids, block_m, global_num_experts, expert_map)
m_indices = torch.repeat_interleave(expert_ids, block_m, dim=0)

I noticed that DeepGEMM updated the documentation for m_grouped_gemm_fp8_fp8_bf16_nt_contiguous in commit deepseek-ai/DeepGEMM@ded740f. In the original version, m_indices marked aligned invalid rows with -1 and skipped them, so I filled -1 for all padding rows in m_indices. That does not seem to match moe_align_block_size, so I will update this part to fill m_indices with expert_id for all rows rather than only valid rows.
I have a question: DeepGEMM checks that m_indices.numel() == lhs.shape[0], but sorted_ids.numel() = max_num_tokens_padded != m_indices.numel() = round_up(max_num_tokens_padded, align_block_size) in the following code.

block_m = 128
sorted_ids, expert_ids, _ = moe_align_block_size(curr_topk_ids, block_m, global_num_experts, expert_map)
m_indices = torch.repeat_interleave(expert_ids, block_m, dim=0)

Also, moe_align_block_size fills m_indices with -1 for experts that are invalid in the ep_map. So we had better view/slice only the aligned rows of permuted_hidden and m_indices for valid experts before computing the grouped GEMM, as in the following code.

permuted_hidden = permuted_hidden[: align_first_token_offset[-1],...]
m_indices = m_indices[: align_first_token_offset[-1]]

Do you have a better idea?

@bnellnm (Contributor) commented Mar 27, 2025

(Quoting CalebDu's reply above in full.)

The -1 in the original DeepGemm documentation was a bug; they never actually supported it (I ran into this while doing the initial integration). In the non-expert_map case, moe_align_block_size fills unused sections with 0s. I haven't really tested DeepGemm with an explicit expert_map, but I guess it would be OK to replace any -1 with 0 (or any other valid expert index) since it will get dropped at the end anyway.

The above code snippet works for DeepGemm (without an explicit expert_map). I'll have to add some tests and code for the user-supplied expert_map. At the moment unused values get chopped off at the end by the inverse map, but it would probably be better if we could chop them off up front like you suggest. I'm not sure of a clean/simple way to do that, since they'd need to be chopped off in 128-element-sized chunks (at least for DeepGemm).

@CalebDu (Contributor Author) commented Mar 27, 2025

(Quoting bnellnm's reply above in full.)

I updated the code to fill padding rows of m_indices with the expert_id for each valid local expert, rather than -1, and to still fill -1 for non-local experts. I think filling 0 (or another valid expert index) in m_indices for non-local experts is not a good idea, because it could cause a lot of redundant compute and memory traffic for the many trailing padding rows in DeepGemm.
If align_block_size is passed to the moe_permute kernel, the output expert_first_token_offset is aligned up to align_block_size. For example, if local experts [0, 1] have [127, 129] tokens respectively, expert_first_token_offset will be [0, 128, 384] rather than [0, 127, 256]. You can use expert_first_token_offset[-1] directly to slice the leading valid rows of permuted_hidden and skip the -1 entries in m_indices.
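In code, the slicing described above would look roughly like this (a sketch; variable names follow the earlier snippets, and the DeepGEMM call is only named, not invoked):

```python
# With align_block_size passed to moe_permute, expert_first_token_offset is
# already aligned, so its last entry bounds the valid rows; everything past it
# belongs to non-local experts (marked -1 in m_indices) and can be dropped.
n_valid = expert_first_token_offset[-1]
lhs = permuted_hidden[:n_valid, ...]
m_indices_valid = m_indices[:n_valid]
# m_grouped_gemm_fp8_fp8_bf16_nt_contiguous then only sees rows whose
# m_indices entries are valid local expert ids.
```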

@bnellnm (Contributor) commented Mar 27, 2025

(Quoting CalebDu's reply above in full.)

Cool, I didn't realize you could use the first-token offsets to exclude the unused bits. In that case I don't think it matters what goes into m_indices, since those rows will be sliced off anyway.

@tlrmchlsmth added the ready label May 2, 2025
@tlrmchlsmth (Collaborator) left a comment

The implementation looks nice and clean. Are there any benchmark results?

(I did spot one potential overflow that should be addressed)

CalebDu added 12 commits May 2, 2025 00:14
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
…ices`rather than -1,

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
1. remove unused code
2. move all the non-trivial definitions from moe_permute_unpermute_kernel.h to .cu and .inl
3. some minor update

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
…ll invoking fused_topk code

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
2. fix potential overflow and remove debug cruft per tlrmchlsmth's review
3. add benchmark for performance

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
@CalebDu (Contributor Author) commented May 2, 2025

@tlrmchlsmth I updated the code per your review and fixed the CI failure related to calling FusedMoE.select_experts.
I also added a benchmark comparing the customized permute/unpermute kernels on DeepSeek-V3 with the naive Python implementation (_moe_permute / _moe_unpermute_and_reduce) in deep_gemm_moe.py.

python vllm/benchmarks/kernels/benchmark_moe_permute_unpermute.py --dtype=fp8_w8a8  --model deepseek-ai/DeepSeek-V3 --trust-remote-code --use-customized-permute

Benchmark on H20:

| Batch size | Permute (customized) | Permute (python) | Unpermute (customized) | Unpermute (python) |
|---:|---:|---:|---:|---:|
| 1 | 65.64 us | 1449.89 us | 11.29 us | 13.10 us |
| 2 | 68.92 us | 1449.52 us | 11.45 us | 12.98 us |
| 4 | 68.61 us | 1449.67 us | 11.71 us | 13.48 us |
| 8 | 70.00 us | 1451.41 us | 12.54 us | 15.40 us |
| 16 | 72.08 us | 1451.32 us | 12.34 us | 21.45 us |
| 24 | 76.10 us | 1456.61 us | 12.37 us | 27.91 us |
| 32 | 81.13 us | 1456.31 us | 12.52 us | 33.58 us |
| 48 | 85.94 us | 1465.27 us | 12.66 us | 45.62 us |
| 64 | 95.78 us | 1469.02 us | 12.77 us | 60.06 us |
| 96 | 109.39 us | 1479.88 us | 14.10 us | 84.60 us |
| 128 | 126.87 us | 1491.65 us | 14.32 us | 110.02 us |
| 256 | 196.66 us | 1537.67 us | 17.87 us | 226.25 us |
| 512 | 335.68 us | 1628.52 us | 28.91 us | 455.32 us |
| 1024 | 629.52 us | 1810.61 us | 45.61 us | 895.50 us |
| 1536 | 914.56 us | 1997.25 us | 52.90 us | 1335.12 us |
| 2048 | 1198.34 us | 2191.29 us | 54.18 us | 1777.86 us |
| 3072 | 1757.11 us | 2583.16 us | 63.88 us | 2656.25 us |
| 4096 | 2307.67 us | 3000.42 us | 83.67 us | 3535.37 us |

@tlrmchlsmth (Collaborator) left a comment

LGTM, thanks for posting the performance numbers!

@simon-mo simon-mo merged commit 3e887d2 into vllm-project:main May 2, 2025
71 of 74 checks passed
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: minpeter <kali2005611@gmail.com>
chenyang78 added a commit to chenyang78/vllm that referenced this pull request Jul 20, 2025
Looks like this was missed from vllm-project#14568

It caused issues when we build vllm with TORCH_CUDA_ARCH_LIST being specified
such as TORCH_CUDA_ARCH_LIST="9.0a". Because we didn't pass CUDA_ARCHS correctly
to compile moe_permute_unpermute_op, we ended up with the following failure:

Without the fix:
```
$ pytest tests/kernels/moe/test_cutlass_moe.py -k test_run_cutlass_moe_fp8[8-True-False-128-1-8192-5120-31]
...
RuntimeError: CUDA error: the provided PTX was compiled with an unsupported toolchain
```

With the fix:
```
$ pytest tests/kernels/moe/test_cutlass_moe.py -k test_run_cutlass_moe_fp8[8-True-False-128-1-8192-5120-31]
...
1 passed, 461 deselected, 2 warnings in 5.00s
```

Tags:
Signed-off-by: Yang Chen <yangche@fb.com>
chenyang78 added a commit to chenyang78/vllm that referenced this pull request Jul 24, 2025
(Same commit message as above.)
chenyang78 added a commit to chenyang78/vllm that referenced this pull request Jul 24, 2025
(Same commit message as above.)