[Kernel] Adding basic Triton JitCache for triton_attn #16606

Open
wants to merge 49 commits into main
Commits
49 commits
b3a01dc
copying jit_cache, adapting jit cache for 2d kernel
bringlein Apr 10, 2025
3490bfc
some cleanup
bringlein Apr 10, 2025
4cc5407
formatting, typos...
bringlein Apr 10, 2025
59755e2
ruff....
bringlein Apr 10, 2025
f114090
adding assume const to jit cache
bringlein Apr 11, 2025
c43006e
experimenting with static launch grid again
bringlein Apr 11, 2025
9da4df6
recovering good performance
bringlein Apr 11, 2025
d7fc0af
going back to static launch grid
bringlein Apr 14, 2025
bf64b6d
formatting...
bringlein Apr 14, 2025
f3fb7e9
make type checking of key arguments more helpful
bringlein Apr 14, 2025
dc3b28c
applying jit cache for prefix prefill
bringlein Apr 14, 2025
e717040
fmt & ruff
bringlein Apr 14, 2025
fe2f6a5
ci
bringlein Apr 14, 2025
14cca7e
remove changed requirements by mistake/pre-hook?
bringlein Apr 14, 2025
d37ef48
fmt...
bringlein Apr 14, 2025
5e4bb2f
removing jit cache from prefix prefill again
bringlein Apr 15, 2025
c711433
cleanup
bringlein Apr 15, 2025
f8c6610
address review comments
bringlein Apr 24, 2025
f8b5001
fix type hints
bringlein Apr 24, 2025
ef3d6a3
add transparency as fallback mode
bringlein Apr 24, 2025
edf8633
CI whacamole
bringlein Apr 24, 2025
10df1df
CI whacamole...
bringlein Apr 24, 2025
cf1cea9
Merge branch 'main' into ngl_jit_cache_pr
bringlein May 7, 2025
f6852ed
adding triton 3.3 support
bringlein May 8, 2025
b93de23
Merge branch 'main' into ngl_jit_cache_pr
bringlein May 8, 2025
72d9858
fixing triton 3.3 support (1/x); add support for unified kernel
bringlein May 8, 2025
eeaab8d
fixing triton 3.3 support (2/2)
bringlein May 9, 2025
9ffc6e4
cleanup and add env var
bringlein May 9, 2025
1c65d75
adding assume_const
bringlein May 9, 2025
43b500b
make argument passing (slightly) faster
bringlein May 9, 2025
43aed8c
Merge branch 'main' into ngl_jit_cache_pr (moving envs content)
bringlein May 13, 2025
e50534a
fixing env var merge conflict
bringlein May 13, 2025
450770c
adding attention metadata specific for triton_backend
bringlein May 13, 2025
f7705c0
fixing env file again
bringlein May 13, 2025
3a5c63e
Revert "adding attention metadata specific for triton_backend"
bringlein May 13, 2025
e2ef23e
more elegant fix on dependency of flash attention
bringlein May 13, 2025
8f5735b
thrid way to un-break triton backend
bringlein May 13, 2025
ccd22c9
CI...
bringlein May 13, 2025
a94e99b
making jitcache safe to use with autotuner
cyang49 May 14, 2025
af094a3
CI whacamole...
bringlein May 14, 2025
c1b21d5
fixup spelling in a few spots
tlrmchlsmth May 20, 2025
be9d7d4
Merge branch 'main' into ngl_jit_cache_pr
tdoublep May 23, 2025
791b8b2
Added support for specialization.
tdoublep May 23, 2025
f4a436a
Merge branch 'main' into ngl_jit_cache_pr
bringlein Jun 12, 2025
f72a768
minor cleanup; remove copy of launch grid
bringlein Jun 13, 2025
02a6ea4
improve docstring
bringlein Jun 18, 2025
cd987c2
Merge branch 'main' into ngl_jit_cache_pr
bringlein Jun 18, 2025
d52af9b
ruff....
bringlein Jun 18, 2025
e1cf444
fixing merge error
bringlein Jun 18, 2025
19 changes: 19 additions & 0 deletions vllm/attention/ops/triton_unified_attention.py
@@ -12,6 +12,7 @@
import triton.language as tl

from vllm.logger import init_logger
from vllm.triton_utils.jit_cache import jitcache

logger = init_logger(__name__)

@@ -47,6 +48,24 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
return left - 1


@jitcache(
check_keys=[],
check_specialization=["num_seqs"],
assume_const=[
"scale",
"k_scale",
"v_scale",
"query_stride_1",
"output_stride_1",
"stride_k_cache_0",
"stride_k_cache_1",
"stride_k_cache_2",
"stride_k_cache_4",
"stride_v_cache_0",
"stride_v_cache_1",
"stride_v_cache_2",
],
)
@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
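For readers unfamiliar with the idea behind this decorator stack: a JIT cache keys the compiled kernel only on the arguments declared in `check_keys`/`check_specialization`, so repeated launches can skip the dispatcher's full per-launch key computation. The sketch below is a hypothetical, Triton-free simplification to illustrate the caching pattern only; it is not the PR's actual `vllm.triton_utils.jit_cache` implementation, and the names `jitcache_sketch` and `kernel` are ours.

```python
def jitcache_sketch(check_keys):
    """Cache one 'compiled' kernel per tuple of declared key-argument values."""
    def decorator(fn):
        cache = {}

        def launcher(**kwargs):
            # Only the declared keys participate in cache lookup; all other
            # arguments are assumed not to require recompilation.
            key = tuple(kwargs[k] for k in check_keys)
            if key not in cache:
                cache[key] = fn  # stand-in for an expensive Triton compile
            return cache[key](**kwargs)

        launcher.cache = cache  # exposed for inspection in this sketch
        return launcher
    return decorator


@jitcache_sketch(check_keys=["BLOCK"])
def kernel(x, BLOCK):
    return x * BLOCK


kernel(x=2, BLOCK=16)   # first launch for BLOCK=16: populates the cache
kernel(x=3, BLOCK=16)   # same key: reuses the cached entry
kernel(x=2, BLOCK=32)   # new key: a second variant is cached
```

In the real decorator, `assume_const` extends this further: those arguments are baked in at first launch and never re-passed, which is why the PR lists the strides and scales there.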
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -76,6 +76,7 @@
VLLM_PLUGINS: Optional[list[str]] = None
VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_TRITON_ENABLE_JITCACHE: bool = False
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@@ -589,6 +590,11 @@ def get_vllm_port() -> Optional[int]:
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

# Enable the JITCache for Triton Kernels
# see triton_utils/jitcache.py
"VLLM_TRITON_ENABLE_JITCACHE":
lambda: bool(int(os.getenv("VLLM_TRITON_ENABLE_JITCACHE", "0"))),

# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
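The new flag is parsed with the same `bool(int(os.getenv(...)))` idiom as the surrounding entries, so only the strings `"1"` (or another nonzero integer) enable it. A minimal standalone illustration of that pattern; the helper name `read_flag` is ours, not vLLM's:

```python
import os

def read_flag(name: str, default: str = "0") -> bool:
    # Mirrors the envs.py pattern: the env string is parsed as an int,
    # then truth-tested, so unset or "0" means disabled.
    return bool(int(os.getenv(name, default)))

os.environ["VLLM_TRITON_ENABLE_JITCACHE"] = "1"
enabled = read_flag("VLLM_TRITON_ENABLE_JITCACHE")
```

Note that non-integer values such as `"true"` would raise `ValueError` under this idiom, so the flag must be set as `VLLM_TRITON_ENABLE_JITCACHE=1`.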