
[Bug]: Benchmark v1 on multi-gpu crashes with ValueError: Pointer argument (at 0) cannot be accessed from Triton #13392

Closed
@huydhn

Description


Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-2)
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.34

Python version: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb  6 2025, 18:56:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk14_hardened_2601_gcd42476b84e9-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
GPU 4: NVIDIA H100
GPU 5: NVIDIA H100
GPU 6: NVIDIA H100
GPU 7: NVIDIA H100

Nvidia driver version: 550.90.07
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.9.6.0
/usr/lib64/libcudnn_adv.so.9.6.0
/usr/lib64/libcudnn_cnn.so.9.6.0
/usr/lib64/libcudnn_engines_precompiled.so.9.6.0
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.6.0
/usr/lib64/libcudnn_graph.so.9.6.0
/usr/lib64/libcudnn_heuristic.so.9.6.0
/usr/lib64/libcudnn_ops.so.9.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             368
On-line CPU(s) list:                0-367
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 9654 96-Core Processor
CPU family:                         25
Model:                              17
Thread(s) per core:                 1
Core(s) per socket:                 368
Socket(s):                          1
Stepping:                           1
BogoMIPS:                           4792.78
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean pausefilter pfthreshold v_vmsave_vmload vgif vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm flush_l1d arch_capabilities
Virtualization:                     AMD-V
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          23 MiB (368 instances)
L1i cache:                          23 MiB (368 instances)
L2 cache:                           184 MiB (368 instances)
L3 cache:                           16 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-367
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:           Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pynvml==12.0.0
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.3
[pip3] triton==3.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-ml-py              12.570.86                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pynvml                    12.0.0                   pypi_0    pypi
[conda] pyzmq                     26.2.1                   pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchaudio                2.5.1                    pypi_0    pypi
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] transformers              4.48.3                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.7.3.dev71+g42288cba
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    0-367   0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    0-367   0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    0-367   0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    0-367   0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    0-367   0               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    0-367   0               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    0-367   0               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      0-367   0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

CUDA_CACHE_PATH=/data/users/huydo/.nv/ComputeCache
LD_LIBRARY_PATH=/usr/local/cuda-12.6/lib64/:
CUDA_HOME=/usr/local/cuda
CUDA_HOME=/usr/local/cuda
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

Benchmarking the V1 engine on multiple GPUs with `ENGINE_VERSION=v1 .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh` crashes with the following error:

(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374] WorkerProc hit an exception: %s
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374] Traceback (most recent call last):
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/multiproc_executor.py", line 370, in worker_busy_loop
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     output = func(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]              ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return func(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/worker/gpu_worker.py", line 154, in determine_available_memory
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     self.model_runner.profile_run()
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/worker/gpu_model_runner.py", line 1283, in profile_run
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     hidden_states = self._dummy_run(self.max_num_tokens,
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return func(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/worker/gpu_model_runner.py", line 1150, in _dummy_run
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     hidden_states = model(
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]                     ^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return forward_call(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/model_executor/models/llama.py", line 547, in forward
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     model_output = self.model(input_ids, positions, kv_caches,
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/compilation/decorators.py", line 238, in __call__
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     output = self.compiled_callable(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return fn(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/model_executor/models/llama.py", line 346, in forward
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     def forward(
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return forward_call(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return fn(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/fx/graph_module.py", line 784, in call_wrapped
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return self._wrapped_call(self, *args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/fx/graph_module.py", line 361, in __call__
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     raise e
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/fx/graph_module.py", line 348, in __call__
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return forward_call(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "<eval_with_key>.162", line 490, in forward
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     submod_0 = self.submod_0(l_input_ids_, s0, l_self_modules_embed_tokens_parameters_weight_, l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_);  l_input_ids_ = l_self_modules_embed_tokens_parameters_weight_ = l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_ = None
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/compilation/backends.py", line 600, in __call__
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return self.compiled_graph_for_general_shape(*args)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/compilation/compiler_interface.py", line 318, in compiled_graph
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     graph_output = inductor_compiled_graph(list_args)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return self.current_callable(inputs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/.cache/vllm/torch_compile_cache/cccb42815a/rank_0/inductor_cache/zx/czxwnqvrpmbnzmt33cqa6igt7rc4blsq45yauaoe5iqqaipyenhs.py", line 362, in call
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     triton_poi_fused_add_bitwise_and_bitwise_or_embedding_ge_lt_masked_fill_mul_sub_0.run(arg0_1, arg2_1, buf0, triton_poi_fused_add_bitwise_and_bitwise_or_embedding_ge_lt_masked_fill_mul_sub_0_xnumel, grid=grid(triton_poi_fused_add_bitwise_and_bitwise_or_embedding_ge_lt_masked_fill_mul_sub_0_xnumel), stream=stream3)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 879, in run
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     return launcher(
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]            ^^^^^^^^^
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "<string>", line 13, in launcher
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374]     self.launch(*args, **kwargs)
(VllmWorker rank=2 pid=3748461) ERROR 02-17 00:02:01 multiproc_executor.py:374] ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
ERROR 02-17 00:02:02 core.py:313] EngineCore hit an exception: Traceback (most recent call last):
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 305, in run_engine_core
ERROR 02-17 00:02:02 core.py:313]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 02-17 00:02:02 core.py:313]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 260, in __init__
ERROR 02-17 00:02:02 core.py:313]     super().__init__(vllm_config, executor_class, log_stats)
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 58, in __init__
ERROR 02-17 00:02:02 core.py:313]     num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
ERROR 02-17 00:02:02 core.py:313]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 107, in _initialize_kv_caches
ERROR 02-17 00:02:02 core.py:313]     available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 02-17 00:02:02 core.py:313]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/abstract.py", line 61, in determine_available_memory
ERROR 02-17 00:02:02 core.py:313]     output = self.collective_rpc("determine_available_memory")
ERROR 02-17 00:02:02 core.py:313]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/multiproc_executor.py", line 133, in collective_rpc
ERROR 02-17 00:02:02 core.py:313]     raise e
ERROR 02-17 00:02:02 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/multiproc_executor.py", line 122, in collective_rpc
ERROR 02-17 00:02:02 core.py:313]     raise result
ERROR 02-17 00:02:02 core.py:313] ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

The error only manifests when the torch.compile cache is already populated and reused, for example when running the same benchmark twice. My understanding is that each vLLM worker uses a separate GPU, but the cached compiled code it loads could create tensors on a different GPU, causing the above error.
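The traceback is consistent with this: the failing worker is `rank=2`, yet the generated Inductor file it executes lives under `.../torch_compile_cache/cccb42815a/rank_0/inductor_cache/...`. As a hypothetical diagnostic (not a vLLM API), the sketch below extracts the `rank_N` component from a cache path and compares it with the worker's own rank. The `rank_N` directory layout is inferred from that single path and may differ in other vLLM versions.

```python
import re
from pathlib import PurePosixPath
from typing import Optional


def cache_rank_from_path(path: str) -> Optional[int]:
    """Return N from the first `rank_N` directory in a compile-cache path, or None."""
    for part in PurePosixPath(path).parts:
        match = re.fullmatch(r"rank_(\d+)", part)
        if match:
            return int(match.group(1))
    return None


def check_cache_rank(path: str, worker_rank: int) -> bool:
    """True if the cache file belongs to this worker's rank (or has no rank component)."""
    cache_rank = cache_rank_from_path(path)
    return cache_rank is None or cache_rank == worker_rank
```

For example, passing the cache file path from the traceback with `worker_rank=2` returns `False`, flagging that rank 2 is running code cached for rank 0.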

Manually clearing the cache directory with `rm -rf ~/.cache/vllm/torch_compile_cache` before each benchmark helps avoid the crash. On the other hand, setting `VLLM_DISABLE_COMPILE_CACHE` doesn't work as expected and fails with a different error. This seems like a separate bug:

ERROR 02-17 01:02:48 core.py:313] EngineCore hit an exception: Traceback (most recent call last):
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 305, in run_engine_core
ERROR 02-17 01:02:48 core.py:313]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 02-17 01:02:48 core.py:313]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 260, in __init__
ERROR 02-17 01:02:48 core.py:313]     super().__init__(vllm_config, executor_class, log_stats)
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 58, in __init__
ERROR 02-17 01:02:48 core.py:313]     num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
ERROR 02-17 01:02:48 core.py:313]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/engine/core.py", line 107, in _initialize_kv_caches
ERROR 02-17 01:02:48 core.py:313]     available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 02-17 01:02:48 core.py:313]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/abstract.py", line 61, in determine_available_memory
ERROR 02-17 01:02:48 core.py:313]     output = self.collective_rpc("determine_available_memory")
ERROR 02-17 01:02:48 core.py:313]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/multiproc_executor.py", line 133, in collective_rpc
ERROR 02-17 01:02:48 core.py:313]     raise e
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/v1/executor/multiproc_executor.py", line 117, in collective_rpc
ERROR 02-17 01:02:48 core.py:313]     status, result = w.worker_response_mq.dequeue(
ERROR 02-17 01:02:48 core.py:313]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 01:02:48 core.py:313]   File "/home/huydo/github/pytorch-integration-testing/vllm-benchmarks/vllm/vllm/distributed/device_communicators/shm_broadcast.py", line 464, in dequeue
ERROR 02-17 01:02:48 core.py:313]     obj = pickle.loads(buf[1:])
ERROR 02-17 01:02:48 core.py:313]           ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-17 01:02:48 core.py:313] TypeError: BackendCompilerFailed.__init__() missing 1 required positional argument: 'inner_exception'
ERROR 02-17 01:02:48 core.py:313]
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374] WorkerProc hit an exception: %s
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374] Traceback (most recent call last):
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374]     compiled_fn = compiler_fn(gm, self.example_inputs())
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374]   File "/home/huydo/miniconda3/envs/py3.12/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
(VllmWorker rank=0 pid=146296) ERROR 02-17 01:02:48 multiproc_executor.py:374]     compiled_gm = compiler_fn(gm, example_inputs)
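The `TypeError` in this second traceback is raised inside `pickle.loads` in `shm_broadcast.py`, i.e. while the engine deserializes the exception sent back by the worker, so the worker's original compile failure is masked. A common cause of this exact message in Python is an exception class whose `__init__` requires more arguments than it forwards to `Exception.__init__`, because pickle rebuilds exceptions from `self.args`. The sketch below reproduces that pattern with a hypothetical class; it is an assumption about the mechanism, not PyTorch's actual definition of `BackendCompilerFailed`.

```python
import pickle


class TwoArgError(Exception):
    """Hypothetical stand-in: __init__ needs two arguments but forwards only a message."""

    def __init__(self, backend_name, inner_exception):
        self.backend_name = backend_name
        self.inner_exception = inner_exception
        # Only this formatted message lands in self.args, which is what pickle replays.
        super().__init__(f"backend={backend_name!r} raised: {inner_exception!r}")


class PicklableTwoArgError(TwoArgError):
    """Same exception, but tells pickle to rebuild it from both constructor arguments."""

    def __reduce__(self):
        return (type(self), (self.backend_name, self.inner_exception))


def roundtrip(exc):
    """Simulate shipping an exception between processes via pickle."""
    return pickle.loads(pickle.dumps(exc))
```

`roundtrip(TwoArgError("inductor", RuntimeError("boom")))` fails with `TypeError: TwoArgError.__init__() missing 1 required positional argument: 'inner_exception'`, while the `__reduce__` variant survives the round trip. Overriding `__reduce__`, or passing every constructor argument to `super().__init__`, is one way to keep such exceptions intact across a process boundary.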

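The manual workaround described above (clearing the compile cache before each benchmark) can be scripted. This is a minimal sketch, assuming the benchmark script is launched with `bash` from the vLLM repository root and that the cache lives at the default `~/.cache/vllm/torch_compile_cache` path quoted in this issue; `clear_compile_cache` and `run_benchmark` are hypothetical helper names.

```python
import os
import shutil
import subprocess
from pathlib import Path

# Cache location quoted in this issue; adjust if your setup uses a different root.
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "vllm" / "torch_compile_cache"
BENCH_SCRIPT = ".buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh"


def clear_compile_cache(cache_dir: Path = DEFAULT_CACHE_DIR) -> bool:
    """Delete the torch.compile cache directory; return True if it existed."""
    existed = cache_dir.exists()
    # Equivalent to `rm -rf <cache_dir>`: ignore_errors mirrors the -f flag.
    shutil.rmtree(cache_dir, ignore_errors=True)
    return existed


def run_benchmark(cmd=("bash", BENCH_SCRIPT), cache_dir: Path = DEFAULT_CACHE_DIR) -> int:
    """Clear the cache, then run the benchmark with the V1 engine selected."""
    clear_compile_cache(cache_dir)
    env = dict(os.environ, ENGINE_VERSION="v1")
    return subprocess.run(list(cmd), env=env, check=False).returncode
```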

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)