Closed as not planned
Labels: bug (Something isn't working), stale (Over 90 days of inactivity)
Description
Your current environment
PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.9
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.58+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 535.104.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.00GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 3
BogoMIPS: 4000.33
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 1 MiB (1 instance)
L3 cache: 38.5 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable
Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.1.2
[pip3] torchaudio==2.2.1+cu121
[pip3] torchdata==0.7.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.17.1
[pip3] torchvision==0.17.1+cu121
[pip3] triton==2.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X 0-1 N/A N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
🐛 Describe the bug
Triton kernel fails to compile when prefix caching is combined with an FP8-quantized KV cache (kv_cache_dtype="fp8_e5m2")
from vllm import LLM, SamplingParams
# 25 prompts: five groups of five identical prompts that differ only in the
# leading digit (0 to 4), so the prompts within each group share a prefix.
SUFFIX = "UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA"
prompts = [f"{i}{SUFFIX}" for i in range(5) for _ in range(5)]
sampling_params = SamplingParams(temperature=0, skip_special_tokens=False)
# The error appears with an fp8_e5m2 KV cache combined with automatic prefix caching.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", kv_cache_dtype="fp8_e5m2", max_model_len=560, enable_prefix_caching=True)
outputs = llm.generate(prompts, sampling_params)
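A likely reason the repro uses repeated prompts: in the traceback below, _fwd_kernel loads keys from K_cache only for positions below cur_batch_ctx_len, which appears to be the part of a prompt whose KV entries are already cached. The sketch below is a simplified, hypothetical model of block-level prefix reuse, not vLLM's actual hashing or scheduler: it treats each character as a token, processes prompts one at a time, and assumes a block size of 16. Under those assumptions only the second and later copies of each prompt get a nonzero cached context, so only they would exercise the K_cache path.

```python
from typing import List, Sequence, Set, Tuple

BLOCK_SIZE = 16  # assumption: KV-cache block size (vLLM's default)


def block_keys(token_ids: Sequence[int], block_size: int = BLOCK_SIZE) -> List[Tuple[int, ...]]:
    """Key each full block by the entire prefix that ends with it.

    Two sequences therefore share block i only if every earlier token matches too.
    A trailing partial block is never cached in this model.
    """
    n_full = len(token_ids) // block_size
    return [tuple(token_ids[: (i + 1) * block_size]) for i in range(n_full)]


def cached_context_len(token_ids: Sequence[int], cache: Set[Tuple[int, ...]],
                       block_size: int = BLOCK_SIZE) -> int:
    """Count leading tokens whose blocks are already cached (a stand-in for ctx_len)."""
    hit = 0
    for key in block_keys(token_ids, block_size):
        if key not in cache:
            break
        hit += block_size
    return hit


def simulate(prompts: Sequence[str], block_size: int = BLOCK_SIZE) -> List[int]:
    """Process prompts sequentially, treating each character as one token id."""
    cache: Set[Tuple[int, ...]] = set()
    ctx_lens: List[int] = []
    for prompt in prompts:
        ids = [ord(ch) for ch in prompt]
        ctx_lens.append(cached_context_len(ids, cache, block_size))
        cache.update(block_keys(ids, block_size))
    return ctx_lens


if __name__ == "__main__":
    suffix = "UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA"
    prompts = [f"{i}{suffix}" for i in range(5) for _ in range(5)]
    # In this model the first prompt of each group has no cached context,
    # and the next four reuse every full block of the first.
    print(simulate(prompts))
```

A real tokenizer produces far fewer tokens than characters, so the absolute lengths here mean nothing; only the zero versus nonzero pattern is the point.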
The reproduction script above fails with:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
1123 try:
-> 1124 generator.visit(fn.parse())
1125 except CompilationError as e:
AssertionError: First input (fp16) and second input (uint8) must have the same dtype!
The above exception was the direct cause of the following exception:
CompilationError Traceback (most recent call last)
<string> in _fwd_kernel(Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, Out, stride_b_loc_b, stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d, stride_k_cache_bl, stride_k_cache_x, stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv, BLOCK_M, BLOCK_DMODEL, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)
/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
1131 if node is None:
1132 raise
-> 1133 raise CompilationError(fn.src, node, repr(e)) from e
1134 ret = generator.module
1135 # module takes ownership of the context
CompilationError: at 96:24: (offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
^
AssertionError('First input (fp16) and second input (uint8) must have the same dtype!')
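The assertion says tl.dot(q, k) received an fp16 q and a uint8 k. With kv_cache_dtype="fp8_e5m2" the cache appears to store raw fp8 bytes in a uint8 tensor, and this kernel loads them from K_cache without converting them, while tl.dot requires both operands to have the same dtype. fp8 e5m2 (1 sign, 5 exponent, 2 mantissa bits, bias 15) has the same layout as the high byte of an IEEE fp16 value, so decoding a stored byte amounts to shifting it into the upper half of a 16-bit word. The plain-Python sketch below illustrates that conversion; the function names are hypothetical, the encoder truncates instead of rounding to nearest, and it is not vLLM's or Triton's code. A kernel fix would presumably need the equivalent conversion of k (and likely of v loaded from V_cache) to q's dtype before each tl.dot.

```python
import struct
from typing import Sequence


def fp16_to_e5m2_byte(value: float) -> int:
    """Encode a float as an fp8 e5m2 byte by truncating its fp16 bit pattern.

    Truncation keeps the sketch short; a real conversion would normally
    round to nearest.
    """
    (half_bits,) = struct.unpack("<H", struct.pack("<e", value))
    return half_bits >> 8  # keep sign, 5 exponent bits, top 2 mantissa bits


def e5m2_byte_to_fp16(byte: int) -> float:
    """Decode a uint8-stored e5m2 value by placing it in the high byte of an fp16."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"expected a byte in [0, 255], got {byte}")
    (value,) = struct.unpack("<e", struct.pack("<H", byte << 8))
    return value


def dot_with_e5m2_keys(q: Sequence[float], k_bytes: Sequence[int]) -> float:
    """Hypothetical dot product: decode the stored key bytes before multiplying,
    instead of combining q with raw uint8 codes."""
    return sum(qi * e5m2_byte_to_fp16(kb) for qi, kb in zip(q, k_bytes))


if __name__ == "__main__":
    k_bytes = [fp16_to_e5m2_byte(v) for v in (1.0, -2.0, 0.1)]
    print(k_bytes, [e5m2_byte_to_fp16(b) for b in k_bytes])
    print(dot_with_e5m2_keys([1.0, 2.0, 4.0], k_bytes))
```

Because only two mantissa bits survive, 0.1 comes back as 0.09375 here, which shows how coarse an e5m2 cache is compared with fp16.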