Skip to content

Commit 5d31fe7

Browse files
committed
update triton version to 2.2.0
1 parent 9970b79 commit 5d31fe7

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ uvicorn[standard]
1111
pydantic >= 2.0 # Required for OpenAI server.
1212
prometheus_client >= 0.18.0
1313
pynvml == 11.5.0
14-
triton >= 2.1.0
14+
triton >= 2.2.0
1515
outlines >= 0.0.27
1616
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.

vllm/model_executor/layers/triton_kernel/prefix_prefill.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import packaging
88

99
assert packaging.version.parse(triton.__version__) >= packaging.version.parse(
10-
"2.1.0"), "Triton version >= 2.1.0 is required."
11-
10+
"2.2.0"), "Triton version >= 2.2.0 is required."
1211

1312
@triton.jit
1413
def _fwd_kernel(
@@ -99,7 +98,7 @@ def _fwd_kernel(
9998
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
10099
k = tl.load(K_cache + off_k,
101100
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
102-
other=0.0)
101+
other=0.0).to(q.dtype)
103102

104103
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
105104
qk += tl.dot(q, k)
@@ -126,7 +125,7 @@ def _fwd_kernel(
126125
# update acc
127126
v = tl.load(V_cache + off_v,
128127
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
129-
other=0.0)
128+
other=0.0).to(k.dtype)
130129

131130
p = p.to(v.dtype)
132131
acc += tl.dot(p, v)

0 commit comments

Comments
 (0)