2 files changed: +4 −5 lines (a requirements file and a Triton kernel under vllm/model_executor/layers/triton_kernel).

Requirements file:

@@ -11,6 +11,6 @@ uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 pynvml == 11.5.0
-triton >= 2.1.0
+triton >= 2.2.0
 outlines >= 0.0.27
 cupy-cuda12x == 12.1.0  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.

vllm/model_executor/layers/triton_kernel (kernel source):

@@ -7,8 +7,7 @@
 import packaging

 assert packaging.version.parse(triton.__version__) >= packaging.version.parse(
-    "2.1.0"), "Triton version >= 2.1.0 is required."
-
+    "2.2.0"), "Triton version >= 2.2.0 is required."

 @triton.jit
 def _fwd_kernel(
@@ -99,7 +98,7 @@ def _fwd_kernel(
             (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
         k = tl.load(K_cache + off_k,
                     mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-                    other=0.0)
+                    other=0.0).to(q.dtype)

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
@@ -126,7 +125,7 @@ def _fwd_kernel(
         # update acc
         v = tl.load(V_cache + off_v,
                     mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
-                    other=0.0)
+                    other=0.0).to(k.dtype)

         p = p.to(v.dtype)
         acc += tl.dot(p, v)
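
For context on the kernel change: the loaded K/V tiles are cast to the query's dtype before they reach tl.dot, presumably so that both tl.dot operands share a dtype when the KV cache is stored in a different precision than the query. Below is a minimal, hypothetical sketch of that pattern; the kernel name, shapes, and dtypes are invented for illustration and are not vLLM's actual kernel, and it assumes a CUDA device with Triton >= 2.2 installed.

import torch
import triton
import triton.language as tl


@triton.jit
def _dot_cast_kernel(q_ptr, k_ptr, out_ptr, BLOCK: tl.constexpr):
    # Hypothetical single-tile kernel: load one BLOCK x BLOCK tile of q and k.
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    q = tl.load(q_ptr + idx)          # loaded in its storage dtype (e.g. fp16)
    # The cached tensor may be stored in a different precision; cast it to the
    # query's dtype so tl.dot sees operands of the same type -- the same
    # pattern as the .to(q.dtype) / .to(k.dtype) casts in the diff above.
    k = tl.load(k_ptr + idx).to(q.dtype)
    acc = tl.dot(q, k)                # accumulates in float32
    tl.store(out_ptr + idx, acc)


if __name__ == "__main__":
    BLOCK = 16
    q = torch.randn(BLOCK, BLOCK, device="cuda", dtype=torch.float16)
    k = torch.randn(BLOCK, BLOCK, device="cuda", dtype=torch.bfloat16)  # mismatched on purpose
    out = torch.empty(BLOCK, BLOCK, device="cuda", dtype=torch.float32)
    _dot_cast_kernel[(1,)](q, k, out, BLOCK=BLOCK)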