NIXL optimization for Llama 4 local attention #87

Open · wants to merge 59 commits into base: pd-launch-branch

Commits (59)

418d2f8
[V1][Spec Decode] Share input embedding of target model with EAGLE dr…
ekagra-ranjan May 14, 2025
f9c069c
Modularize fused experts and integrate PPLX kernels (#15956)
bnellnm May 14, 2025
8568650
[CI] Disable Failing Tests (#18165)
robertgshaw2-redhat May 14, 2025
7e55a34
Local attention optimization for NIXL
mgoin May 14, 2025
8ea467d
Clean up a lot!
mgoin May 14, 2025
73a8272
Small opt
mgoin May 14, 2025
749f792
[Frontend] decrease import time of vllm.multimodal (#18031)
davidxia May 14, 2025
d93c976
[Kernel] Have rotary embeddings support tensors (#18046)
LucasWilkinson May 14, 2025
2fc9075
[V1] Structured Outputs + Thinking compatibility (#16577)
aarnphm May 14, 2025
17cc4c9
Fix mypy
mgoin May 14, 2025
7974736
Add support for loading torchao models with `AOPerModuleConfig` (#17826)
jerryzh168 May 14, 2025
78aa341
[CI] Fix race condition in test_kv_cache_events test (#18169)
russellb May 14, 2025
2142035
[V1] Support multiple kv connectors (#17564)
mgoin May 14, 2025
09f106a
Upload vllm index for the rc builds (#18173)
atalman May 14, 2025
f25e0d1
[Bugfix]: make most of `test_openai_schema.py` pass (#17664)
davidxia May 15, 2025
e60f550
[v1] Support multiple KV cache groups in GPU model runner (#17945)
heheda12345 May 15, 2025
65334ef
[V1][Metrics] Remove unused code (#18158)
markmc May 15, 2025
afe3236
[Chore] astral's ty (#18116)
aarnphm May 15, 2025
2dff093
[Misc] add lobe-chat support (#18177)
reidliu41 May 15, 2025
83f74c6
[Fix][ROCm] Enforce eager for all encoder-decoder models on ROCm (#18…
ProExpertProg May 15, 2025
26d0419
Update deprecated type hinting in `models` (#18132)
hmellor May 15, 2025
e6b8e65
[Bugfix] Fix fp8 tests for triton_unified_attention for Triton 3.3 (#…
tdoublep May 15, 2025
4f07a64
Support custom implementations of VideoLoader backends. (#18091)
huachenheli May 15, 2025
420caf7
[UT] Add ut for none hash (#17892)
andyxning May 15, 2025
dd2a945
[Model] Allow the use of sliding window in Qwen2 (#17772)
inkcherry May 15, 2025
70f8b96
[Bugfix] Fix FusedMoEPrepareAndFinalize for cuda-disalike backends (#…
MengqingCao May 15, 2025
de71fec
[CI] don't skip fixed `test_kv_cache_events()` (#18183)
davidxia May 15, 2025
a8f5aec
[V1] Update zmq socket creation in nixl connector (#18148)
russellb May 15, 2025
a9944aa
fix: typos (#18151)
omahs May 15, 2025
07ad271
Update deprecated type hinting in `model_loader` (#18130)
hmellor May 15, 2025
451da4b
add tools into TokenizeChatRequest (#18187)
hustxiayang May 15, 2025
01c2233
[Kernel] [V1] Fix performance regression for triton unified attention…
tdoublep May 15, 2025
566ec04
Adding "Basic Models Test" and "Multi-Modal Models Test (Extended) 3"…
Alexei-V-Ivanov-AMD May 15, 2025
51ff154
Improve examples rendering in docs and GitHub (#18203)
hmellor May 15, 2025
2aa5470
[Frontend] Fix chat template content format detection (#18190)
schoennenbeck May 15, 2025
fadb8d5
[Bugfix]Change the exception thrown by call_hf_processor from Runtime…
Abatom May 15, 2025
9254052
[Bugfix] [ROCm]: Remove assertion logic when using AITER fused moe in…
tjtanaa May 15, 2025
e3f3aee
[Misc] Avoid cuda graph log when sizes still match (#18202)
NickLucche May 15, 2025
0b34593
Adding "AMD: Tensorizer Test" to amdproduction. (#18216)
Alexei-V-Ivanov-AMD May 15, 2025
8795eb9
[Bugfix] Fix test_eagle test (#18223)
luccafong May 15, 2025
c7852a6
[Build] Allow shipping PTX on a per-file basis (#18155)
LucasWilkinson May 15, 2025
4e1c6a0
[Bugfix] fix rotary embedding test for _get_padded_tensor_shape (#18229)
LucasWilkinson May 16, 2025
ee659e3
[Bugfix][ROCm] Use `chunked_prefill_paged_decode` as fallback for V1 …
kliuae May 16, 2025
f4937a5
[Model] vLLM v1 supports Medusa (#17956)
skylee-01 May 16, 2025
b18201f
Allow users to pass arbitrary JSON keys from CLI (#18208)
hmellor May 16, 2025
6b31c84
Throw better error for when running into k8s service discovery issue …
wseaton May 16, 2025
3d2779c
[Feature] Support Pipeline Parallism in torchrun SPMD offline inferen…
luccafong May 16, 2025
5c04bb8
[doc] fix multimodal example script (#18089)
davidxia May 16, 2025
67da572
[PERF] Speed up Qwen2.5-VL model by speed up rotary position embeddin…
vadiklyutiy May 16, 2025
5418176
[Misc] Add Ray Prometheus logger to V1 (#17925)
eicherseiji May 16, 2025
390ec88
[Misc] Consolidate Audio tests into multimodal common generation test…
Isotr0py May 16, 2025
e23564c
use ceil_div in cutlass block scaling shape check (#17918)
IwakuraRein May 16, 2025
a5f8c11
[Fix] Fix typo in `resolve_hf_chat_template` (#18259)
fxmarty-amd May 16, 2025
87d8714
[Model] Use autoweightloader for dbrx (#18251)
learner0810 May 16, 2025
d3d91b6
[Misc][MacOS] fix bfloat16 error (#18249)
reidliu41 May 16, 2025
1db4f47
[BugFix] Fix multi async save in MultiConnector (#18246)
njhill May 16, 2025
af2f264
Add TODO to remove
mgoin May 16, 2025
e8fd2f1
Merge branch 'main' into nixl-l4-opt
mgoin May 16, 2025
7f0ef82
Bug fixes
mgoin May 16, 2025
8 changes: 8 additions & 0 deletions .buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -82,6 +82,14 @@ if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"*
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
fi

if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
fi

if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then
commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"}
fi

if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
fi
1 change: 1 addition & 0 deletions .buildkite/scripts/upload-wheels.sh
@@ -75,3 +75,4 @@ else
fi

aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"
11 changes: 6 additions & 5 deletions .buildkite/test-pipeline.yaml
@@ -148,6 +148,8 @@ steps:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
@@ -216,7 +218,6 @@ steps:
- pytest -v -s v1/spec_decode
- pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py
# TODO: accuracy does not match, whether setting
@@ -380,7 +381,7 @@ steps:
- pytest -v -s kernels/mamba

- label: Tensorizer Test # 11min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
@@ -456,7 +457,7 @@ steps:
##### models test #####

- label: Basic Models Test # 24min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -528,7 +529,7 @@ steps:
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'

- label: Multi-Modal Models Test (Extended) 3
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
optional: true
source_file_dependencies:
- vllm/
@@ -538,7 +539,7 @@
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'

- label: Quantized Models Test
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/model_executor/layers/quantization
- tests/models/quantization
9 changes: 6 additions & 3 deletions CMakeLists.txt
@@ -301,7 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)

#
@@ -445,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
@@ -675,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)

#
12 changes: 10 additions & 2 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -115,8 +115,16 @@ def bench_fp8(
a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)

def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y

block_scale_a = torch.rand(
(m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
)
block_scale_b = torch.rand(
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
)
block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
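Note on the bench_fp8 change above: switching from floor division to ceil_div matters whenever k or n is not a multiple of the 128-wide scaling block, since floor division under-allocates the block-scale tensors. A minimal sketch, using assumed shapes (k=300, n=200) that are not taken from the benchmark itself:

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

k, n = 300, 200               # hypothetical sizes that do not divide evenly by 128
assert k // 128 == 2          # floor division drops the partial trailing block
assert ceil_div(k, 128) == 3  # ceil_div allocates a scale entry for the tail block
assert n // 128 == 1 and ceil_div(n, 128) == 2
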
89 changes: 69 additions & 20 deletions cmake/utils.cmake
@@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
"${multiValueArgs}" ${ARGN} )

foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_ARCH}"
CODE "sm_${_ARCH}")
# handle +PTX suffix: generate both sm and ptx codes if requested
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
if(NOT _HAS_PTX EQUAL -1)
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "compute_${_STRIPPED_ARCH}")
else()
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
endif()
endforeach()

if (${arg_BUILD_PTX_FOR_ARCH})
@@ -251,7 +266,10 @@ endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
@@ -268,44 +286,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
# Example With PTX:
# SRC_CUDA_ARCHS="8.0+PTX"
# TGT_CUDA_ARCHS="9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0+PTX"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})

# handle +PTX suffix: separate base arch for matching, record PTX requests
set(_PTX_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "\\+PTX$")
string(REPLACE "+PTX" "" _base "${_arch}")
list(APPEND _PTX_ARCHS "${_base}")
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
list(APPEND _SRC_CUDA_ARCHS "${_base}")
endif()
endforeach()
list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)

# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
set(_CUDA_ARCHS "9.0a")
endif()
endif()

if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
set(_CUDA_ARCHS "10.0a")
endif()
endif()

list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version).
foreach(_ARCH ${TGT_CUDA_ARCHS_})
foreach(_ARCH ${_TGT_CUDA_ARCHS})
set(_TMP_ARCH)
# Extract the major version of the target arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
# Extract the major version of the source arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
# Check major-version match AND version-less-or-equal
# Check version-less-or-equal, and allow PTX arches to match across majors
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}")
endif()
else()
@@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
endforeach()

list(REMOVE_DUPLICATES _CUDA_ARCHS)

# reapply +PTX suffix to architectures that requested PTX
set(_FINAL_ARCHS)
foreach(_arch ${_CUDA_ARCHS})
if(_arch IN_LIST _PTX_ARCHS)
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
else()
list(APPEND _FINAL_ARCHS "${_arch}")
endif()
endforeach()
set(_CUDA_ARCHS ${_FINAL_ARCHS})

set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
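
For readers skimming the CMake above, the following is an illustrative Python sketch (not part of the build) of the loose intersection with the +PTX suffix. It assumes plain major.minor version strings and ignores letter suffixes such as 9.0a, which the real macro handles separately:

def loose_intersection(src: list[str], tgt: list[str]) -> list[str]:
    ptx = {a.removesuffix("+PTX") for a in src if a.endswith("+PTX")}
    bases = sorted({a.removesuffix("+PTX") for a in src}, key=float)
    out = set()
    for t in tgt:
        best = None
        for s in bases:  # ascending, so the last match is the highest arch <= t
            if float(s) <= float(t) and (s in ptx or int(float(s)) == int(float(t))):
                best = s  # PTX archs are allowed to match a larger major version
        if best is not None:
            out.add(best + "+PTX" if best in ptx else best)
    return sorted(out, key=lambda a: float(a.removesuffix("+PTX")))

print(loose_intersection(["8.0+PTX"], ["9.0"]))                   # ['8.0+PTX']
print(loose_intersection(["7.5", "8.0", "9.0"], ["8.6", "9.0"]))  # ['8.0', '9.0']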

3 changes: 3 additions & 0 deletions csrc/activation_kernels.cu
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
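The three added lines above guard against an empty batch: the grid is sized by num_tokens, and a CUDA kernel launch with a zero-block grid is an invalid configuration, so the macro now returns before dispatching. A rough, purely illustrative Python mirror of that guard:

import torch

def should_skip_launch(x: torch.Tensor) -> bool:
    # Mirrors num_tokens = input.numel() / input.size(-1) in the macro.
    num_tokens = x.numel() // x.size(-1)
    return num_tokens == 0

assert should_skip_launch(torch.empty(0, 256))      # zero tokens: skip the kernel
assert not should_skip_launch(torch.randn(4, 256))  # normal batch: launch as before
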
4 changes: 2 additions & 2 deletions csrc/attention/attention_kernels.cuh
@@ -172,7 +172,7 @@ __device__ void paged_attention_kernel(

// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
@@ -259,7 +259,7 @@ __device__ void paged_attention_kernel(

// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
14 changes: 14 additions & 0 deletions csrc/dispatch_utils.h
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
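
The new VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES macro enumerates the unsigned dtypes (Byte, UInt16, UInt32, UInt64) alongside the signed ones, so the MoE alignment kernels below can accept unsigned topk_ids tensors. A hedged Python-level sketch of the widened set, assuming a PyTorch build that exposes the unsigned dtypes:

import torch

SIGNED = {torch.int8, torch.int16, torch.int32, torch.int64}
UNSIGNED = {torch.uint8, torch.uint16, torch.uint32, torch.uint64}

def dispatchable(t: torch.Tensor) -> bool:
    # Analogue of VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES; the previous
    # integral-only macro covered the standard signed set plus uint8 (Byte),
    # but not the wider unsigned types.
    return t.dtype in SIGNED | UNSIGNED

assert dispatchable(torch.zeros(8, dtype=torch.int32))
assert dispatchable(torch.zeros(8, dtype=torch.uint32))  # newly dispatchable
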
8 changes: 4 additions & 4 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}

if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");

VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =