
Commit a325146

elvischenv authored and 0xrushi committed
Bump Flashinfer to v0.4.0 (vllm-project#26326)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent 7007bce commit a325146

File tree: 7 files changed, +25 −23 lines

docker/Dockerfile
docker/Dockerfile.nightly_torch
setup.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/quantization/nvfp4_utils.py
tests/quantization/test_blackwell_moe.py
vllm/v1/attention/backends/flashinfer.py

docker/Dockerfile

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@ ARG PYTHON_VERSION=3.12
 # Example:
 # docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
 
-# Important: We build with an old version of Ubuntu to maintain broad
+# Important: We build with an old version of Ubuntu to maintain broad
 # compatibility with other Linux OSes. The main reason for this is that the
 # glibc version is baked into the distro, and binaries built with one glibc
 # version are not backwards compatible with OSes that use an earlier version.

@@ -371,7 +371,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 # Install FlashInfer from source
 ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
 # Keep this in sync with "flashinfer" extra in setup.py
-ARG FLASHINFER_GIT_REF="v0.3.1"
+ARG FLASHINFER_GIT_REF="v0.4.0"
 # Flag to control whether to compile FlashInfer AOT kernels
 # Set to "true" to enable AOT compilation:
 # docker build --build-arg FLASHINFER_AOT_COMPILE=true ...

@@ -392,7 +392,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
         FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
     fi
     pushd flashinfer
-    if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then
+    if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ] && [ "${FLASHINFER_GIT_REF}" = "v0.3.1" ]; then
        # NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
        echo "🏗️ Installing FlashInfer from pre-compiled wheel"
        uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \

docker/Dockerfile.nightly_torch

Lines changed: 2 additions & 2 deletions
@@ -246,15 +246,15 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
 
 
 # build flashinfer for torch nightly from source around 10 mins
-# release version: v0.3.1
+# release version: v0.4.0
 # todo(elainewy): cache flashinfer build result for faster build
 ENV CCACHE_DIR=/root/.cache/ccache
 RUN --mount=type=cache,target=/root/.cache/ccache \
     --mount=type=cache,target=/root/.cache/uv \
     echo "git clone flashinfer..." \
     && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
     && cd flashinfer \
-    && git checkout v0.3.1 \
+    && git checkout v0.4.0 \
     && git submodule update --init --recursive \
     && echo "finish git clone flashinfer..." \
     && rm -rf build \

setup.py

Lines changed: 1 addition & 1 deletion
@@ -715,7 +715,7 @@ def _read_requirements(filename: str) -> list[str]:
         ], # Required for audio processing
         "video": [], # Kept for backwards compatibility
         # FlashInfer should be updated together with the Dockerfile
-        "flashinfer": ["flashinfer-python==0.3.1"],
+        "flashinfer": ["flashinfer-python==0.4.0"],
         # Optional deps for AMD FP4 quantization support
         "petit-kernel": ["petit-kernel"],
     },
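
After installing the extra (e.g. `pip install "vllm[flashinfer]"`), the pin can be sanity-checked with the standard library. A minimal sketch, assuming the distribution name `flashinfer-python` used in the pin above:

```python
# Sketch: confirm the "flashinfer" extra resolved to the pinned release.
from importlib import metadata

installed = metadata.version("flashinfer-python")
assert installed == "0.4.0", f"expected flashinfer-python==0.4.0, got {installed}"
print(f"flashinfer-python {installed} is installed")
```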

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 9 additions & 12 deletions
@@ -7,9 +7,8 @@
 import torch
 
 from tests.kernels.quantization.nvfp4_utils import (
-    FLOAT4_E2M1_MAX,
-    FLOAT8_E4M3_MAX,
     dequantize_nvfp4_to_dtype,
+    get_nvfp4_global_scale,
 )
 from vllm.platforms import current_platform
 from vllm.utils import round_up

@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
     o_scale = 1.0
-    o_sf_scale = None
+    o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = (
-            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-        ).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
+        o_sf_scale_float = o_sf_scale.item()
 
     # TRTLLM Decode
     if o_quant_dtype == FP4_DTYPE:

@@ -204,7 +202,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
         window_left=window_left,
-        o_sf_scale=o_sf_scale,
+        o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
     if o_quant_dtype == FP8_DTYPE:

@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
     o_scale = 1.0
-    o_sf_scale = None
+    o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = (
-            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-        ).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
+        o_sf_scale_float = o_sf_scale.item()
 
     # TRTLLM Prefill
     if o_quant_dtype == FP4_DTYPE:

@@ -398,7 +395,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         window_left=window_left,
-        o_sf_scale=o_sf_scale,
+        o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
     if o_quant_dtype == FP8_DTYPE:
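
The test change separates the two uses of the NVFP4 output scale: the tensor returned by `get_nvfp4_global_scale` stays available for the reference quantization path, while the TRTLLM attention call sites now receive a plain Python float. A minimal sketch of that conversion, using an illustrative stand-in value rather than a real attention output:

```python
# Sketch only: o_sf_scale stands in for get_nvfp4_global_scale(output), which
# returns a 0-d tensor; the updated call sites pass .item() as o_sf_scale=.
import torch

o_sf_scale = torch.tensor(13.37)      # 0-d tensor scale (illustrative value)
o_sf_scale_float = o_sf_scale.item()  # plain float handed to the TRTLLM kernel

assert isinstance(o_sf_scale_float, float)
```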

tests/kernels/quantization/nvfp4_utils.py

Lines changed: 5 additions & 3 deletions
@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
     return values.reshape(m, n * 2).to(dtype=dtype)
 
 
+def get_nvfp4_global_scale(a: torch.Tensor):
+    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
+
+
 def quant_nvfp4_tensor(a: torch.Tensor):
-    a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
-        torch.float32
-    )
+    a_global_scale = get_nvfp4_global_scale(a)
     a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
     return a_quant, a_block_scale, a_global_scale
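
As a worked example of the factored-out helper (a sketch, assuming the usual E4M3 maximum of 448.0 for `FLOAT8_E4M3_MAX` and the E2M1 maximum of 6.0 for `FLOAT4_E2M1_MAX`): the global scale is the combined FP8×FP4 dynamic range divided by the tensor's absolute maximum.

```python
# Sketch of get_nvfp4_global_scale; the constants are assumed values for the
# E4M3 (448.0) and E2M1 (6.0) maxima, not copied from this file.
import torch

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0


def get_nvfp4_global_scale(a: torch.Tensor) -> torch.Tensor:
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)


a = torch.tensor([[-4.0, 2.0], [1.0, 3.0]])
print(get_nvfp4_global_scale(a))  # abs-max is 4.0, so 448.0 * 6.0 / 4.0 = tensor(672.)
```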

tests/quantization/test_blackwell_moe.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None):
     with RemoteOpenAIServer(
         model,
         server_args,
-        max_wait_seconds=1000, # Due to FlashInfer compile
+        max_wait_seconds=1500, # Due to FlashInfer compile
         override_hf_configs=dummy_hf_overrides,
     ) as server:
         client = server.get_client()

vllm/v1/attention/backends/flashinfer.py

Lines changed: 4 additions & 1 deletion
@@ -1199,7 +1199,7 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
-        # Make sure we pass exactly 15 arguments for tensor core version
+        # Make sure we pass exactly 18 arguments for tensor core version
        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,

@@ -1216,6 +1216,9 @@ def fast_plan_decode(
            head_dim,
            head_dim,
            False, # causal
+            window_left,
+            -1, # fixed_split_size
+            False, # disable_split_kv
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e
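
The comment change (15 to 18 arguments) and the three trailing arguments (`window_left`, `fixed_split_size`, `disable_split_kv`) track the tensor-core `plan()` signature in FlashInfer v0.4.0, which suggests an older wheel would not accept this call. A hedged guard sketch, not part of this commit, assuming a plain X.Y.Z version string for the installed `flashinfer-python` distribution:

```python
# Hypothetical guard (illustration only): fail early if the installed
# FlashInfer predates the 18-argument tensor-core plan() used above.
from importlib import metadata

ver = metadata.version("flashinfer-python")  # e.g. "0.4.0"
if tuple(int(p) for p in ver.split(".")[:3]) < (0, 4, 0):
    raise RuntimeError(f"fast_plan_decode expects flashinfer-python>=0.4.0, got {ver}")
```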
