Description
Your current environment
The output of python collect_env.py
INFO 06-09 15:59:09 [__init__.py:248] Automatically detected platform rocm.
Collecting environment information...
==============================
System Info
==============================
OS : Ubuntu 22.04.5 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : 18.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-6.3.1 24491 1e0fda770a2079fbd71e4b70974d74f62fd3af10)
CMake version : version 3.31.4
Libc version : glibc-2.35
==============================
PyTorch Info
==============================
PyTorch version : 2.7.0a0+git6c0e746
Is debug build : False
CUDA used to build PyTorch : N/A
ROCM used to build PyTorch : 6.3.42133-1b9c17779
==============================
Python Environment
==============================
Python version : 3.12.9 (main, Feb 5 2025, 08:49:00) [GCC 11.4.0] (64-bit runtime)
Python platform : Linux-6.8.0-52-generic-x86_64-with-glibc2.35
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : Could not collect
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration : AMD Instinct MI250X/MI250 (gfx90a:sramecc+:xnack-)
Nvidia driver version : Could not collect
cuDNN version : Could not collect
HIP runtime version : 6.3.42133
MIOpen runtime version : 3.3.0
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7713 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3720.7029
CPU min MHz: 1500.0000
BogoMIPS: 3992.52
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-15
NUMA node1 CPU(s): 16-31
NUMA node2 CPU(s): 32-47
NUMA node3 CPU(s): 48-63
NUMA node4 CPU(s): 64-79
NUMA node5 CPU(s): 80-95
NUMA node6 CPU(s): 96-111
NUMA node7 CPU(s): 112-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] numpy==1.26.4
[pip3] pyzmq==26.2.1
[pip3] torch==2.7.0a0+git6c0e746
[pip3] torchvision==0.21.0+7af6987
[pip3] transformers==4.52.3
[pip3] triton==3.2.0+gite5be006a
[conda] Could not collect
==============================
vLLM Info
==============================
ROCM Version : 6.3.42133-1b9c17779
Neuron SDK Version : N/A
vLLM Version : 0.8.5.dev681+g964472b96.d20250528 (git sha: 964472b96, date: 20250528)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 15 15 30 30 30 15 30
GPU1 15 0 30 15 30 15 30 45
GPU2 15 30 0 15 15 30 30 30
GPU3 30 15 15 0 30 45 30 15
GPU4 30 30 15 30 0 15 15 30
GPU5 30 15 30 45 15 0 30 15
GPU6 15 30 30 30 15 30 0 15
GPU7 30 45 30 15 30 15 15 0
================================= Hops between two GPUs ==================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 1 1 1 1 1 1 1
GPU1 1 0 1 1 1 1 1 1
GPU2 1 1 0 1 1 1 1 1
GPU3 1 1 1 0 1 1 1 1
GPU4 1 1 1 1 0 1 1 1
GPU5 1 1 1 1 1 0 1 1
GPU6 1 1 1 1 1 1 0 1
GPU7 1 1 1 1 1 1 1 0
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
======================================= Numa Nodes =======================================
GPU[0] : (Topology) Numa Node: 3
GPU[0] : (Topology) Numa Affinity: 3
GPU[1] : (Topology) Numa Node: 3
GPU[1] : (Topology) Numa Affinity: 3
GPU[2] : (Topology) Numa Node: 2
GPU[2] : (Topology) Numa Affinity: 2
GPU[3] : (Topology) Numa Node: 2
GPU[3] : (Topology) Numa Affinity: 2
GPU[4] : (Topology) Numa Node: 7
GPU[4] : (Topology) Numa Affinity: 7
GPU[5] : (Topology) Numa Node: 7
GPU[5] : (Topology) Numa Affinity: 7
GPU[6] : (Topology) Numa Node: 6
GPU[6] : (Topology) Numa Affinity: 6
GPU[7] : (Topology) Numa Node: 6
GPU[7] : (Topology) Numa Affinity: 6
================================== End of ROCm SMI Log ===================================
==============================
Environment Variables
==============================
PYTORCH_ROCM_ARCH=gfx90a;gfx942
VLLM_ROCM_CUSTOM_PAGED_ATTN=1
VLLM_TARGET_DEVICE=rocm
LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
VLLM_USE_V1=1
VERBOSE=1
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_root
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
Multiple architectures, such as Qwen2, use Sliding Window Attention. However, V1 currently provides no way to run Sliding Window Attention on ROCm. For example, sending a request to a server launched as follows crashes it:
MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
export VLLM_USE_V1=1
python3 -m vllm.entrypoints.openai.api_server --port 8080 --model $MODEL_NAME --served-model-name $MODEL_NAME --gpu-memory-utilization 0.95 --disable-custom-all-reduce --tensor-parallel-size 1 --enable-chunked-prefill --disable-log-requests --enable-reasoning --reasoning-parser deepseek_r1
This is because V1 uses the Triton attention backend, which does not support Sliding Window Attention. As a result, sending a request to the server crashes vLLM:
Request:
curl -iX POST "http://localhost:8080/v1/chat/completions" -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"messages": [{ "role": "user", "content": "Hello, how are you?"}],
"stream": false
}'
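For scripted reproduction, the same request can also be built in Python. The sketch below uses only the standard library and assumes the server launched above is listening on localhost:8080; `build_chat_request` and `send_chat_request` are hypothetical helper names, not part of vLLM.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, content: str) -> urllib.request.Request:
    """Build (but do not send) the same POST request as the curl command above."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat_request(req: urllib.request.Request, timeout: float = 60.0):
    """Send the request and return (HTTP status, response body as text)."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.status, resp.read().decode("utf-8")
```

On the affected setup, calling `send_chat_request(build_chat_request("http://localhost:8080", "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "Hello, how are you?"))` is expected to trigger the same engine crash; the client then sees an error rather than a completion.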
Resulting logs from crash:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/triton/language/core.py", line 34, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/language/core.py", line 1281, in arange
return semantic.arange(start, end, _builder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/language/semantic.py", line 610, in arange
raise ValueError("arange's range must be a power of 2")
ValueError: arange's range must be a power of 2
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/vllm/vllm/v1/engine/core.py", line 493, in run_engine_core
raise e
File "/vllm/vllm/v1/engine/core.py", line 482, in run_engine_core
engine_core.run_busy_loop()
File "/vllm/vllm/v1/engine/core.py", line 509, in run_busy_loop
self._process_engine_step()
File "/vllm/vllm/v1/engine/core.py", line 534, in _process_engine_step
outputs = self.step_fn()
^^^^^^^^^^^^^^
File "/vllm/vllm/v1/engine/core.py", line 222, in step
model_output = self.execute_model(scheduler_output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/v1/engine/core.py", line 209, in execute_model
raise err
File "/vllm/vllm/v1/engine/core.py", line 203, in execute_model
return self.model_executor.execute_model(scheduler_output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/v1/executor/abstract.py", line 86, in execute_model
output = self.collective_rpc("execute_model",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/utils.py", line 2534, in run_method
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/v1/worker/gpu_worker.py", line 276, in execute_model
output = self.model_runner.execute_model(scheduler_output,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/v1/worker/gpu_model_runner.py", line 1156, in execute_model
model_output = self.model(
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/model_executor/models/qwen2.py", line 480, in forward
hidden_states = self.model(input_ids, positions, intermediate_tensors,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/compilation/decorators.py", line 245, in __call__
model_output = self.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/model_executor/models/qwen2.py", line 339, in forward
def forward(
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 764, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 406, in __call__
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 393, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.58", line 212, in forward
submod_1 = self.submod_1(getitem, s0, getitem_1, getitem_2, getitem_3); getitem = getitem_1 = getitem_2 = submod_1 = None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 406, in __call__
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 393, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.2", line 5, in forward
unified_attention_with_output = torch.ops.vllm.unified_attention_with_output(query_2, key_2, value, output_1, 'model.layers.0.self_attn.attn'); query_2 = key_2 = value = output_1 = unified_attention_with_output = None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1156, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/vllm/vllm/attention/layer.py", line 425, in unified_attention_with_output
self.impl.forward(self,
File "/vllm/vllm/v1/attention/backends/triton_attn.py", line 201, in forward
unified_attention(
File "/vllm/vllm/attention/ops/triton_unified_attention.py", line 294, in unified_attention
kernel_unified_attention_2d[(
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 330, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 623, in run
kernel = self.compile(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 280, in compile
module = src.make_ir(options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 85, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 67:13:
q_block_local_idx = q_block_global_idx - q_block_start_idx
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index \
- cur_batch_in_all_start_index
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
^
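The root error in the traceback is Triton's requirement that the length of a `tl.arange(start, end)` range be a power of two, and the failing line is `tl.arange(0, BLOCK_Q * num_queries_per_kv)`. The plain-Python sketch below mirrors that check. It does not call Triton, and the (BLOCK_Q, num_queries_per_kv) pairs are hypothetical values chosen for illustration, not taken from the kernel's actual launch configuration.

```python
def is_power_of_two(n: int) -> bool:
    """True if n is a positive power of two (1, 2, 4, 8, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def arange_length_ok(block_q: int, num_queries_per_kv: int) -> bool:
    """Mirror Triton's constraint on tl.arange(0, block_q * num_queries_per_kv)."""
    return is_power_of_two(block_q * num_queries_per_kv)


# Hypothetical (BLOCK_Q, num_queries_per_kv) pairs, for illustration only.
for block_q, q_per_kv in [(4, 4), (8, 2), (2, 6), (3, 4)]:
    status = "ok" if arange_length_ok(block_q, q_per_kv) else "would raise ValueError"
    print(f"BLOCK_Q={block_q}, num_queries_per_kv={q_per_kv}: "
          f"range {block_q * q_per_kv} -> {status}")
```

A product of positive integers is a power of two only if every factor is, so whenever `num_queries_per_kv` is not a power of two, no choice of `BLOCK_Q` can satisfy this constraint.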
In V0, ROCm's Custom Paged Attention can be used instead. On V1, however, it is only enabled when sliding window is disabled, apparently because of numerical discrepancies observed on V1:
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
if ON_GFX9:
return ((not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))
else:
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 32768 and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
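The sliding-window clause is the part of this condition that differs between V0 and V1. Isolated as a standalone function, it reduces to the sketch below. The function name `sliding_window_permits_custom_paged_attn` is hypothetical, the other clauses (dtype, head size, block size, GQA ratio, and so on) are deliberately left out, and the "disabled" encodings simply mirror the two values checked in the excerpt.

```python
from typing import Tuple, Union

SlidingWindow = Union[int, Tuple[int, int]]


def sliding_window_permits_custom_paged_attn(use_v1: bool,
                                             sliding_window: SlidingWindow) -> bool:
    """Mirror only the sliding-window clause of the excerpt above.

    On V0 the clause is always satisfied; on V1 it holds only when sliding
    window is disabled, i.e. encoded as 0 or (-1, -1).
    """
    return (not use_v1) or sliding_window == 0 or sliding_window == (-1, -1)


# V0: sliding window never blocks custom paged attention.
print(sliding_window_permits_custom_paged_attn(False, 4096))     # True
# V1 with sliding window disabled.
print(sliding_window_permits_custom_paged_attn(True, (-1, -1)))  # True
# V1 with a (hypothetical) 4096-token window: custom paged attention is ruled out.
print(sliding_window_permits_custom_paged_attn(True, 4096))      # False
```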
Setting VLLM_USE_TRITON_FLASH_ATTN=0 does not help on V1: the engine still selects the Triton attention backend despite the flag, and crashes as above once a request is sent. For example, the logs show:
INFO 06-09 16:08:58 [rocm.py:184] Using Triton Attention backend on V1 engine.
As a result, there is currently no way to run Qwen2 architectures, or any other architecture that uses Sliding Window Attention, on ROCm with V1. Given the plans to deprecate V0, this is a serious concern for ROCm support.