[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16) #461
VeeraRajasekhar wants to merge 7 commits into dev from
Conversation
Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention (Q length 1, KV length 2-16, large batch).
- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
fused_attn_rocm/: declarations and implementation adapted from
varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.
- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
transformer_engine/common/CMakeLists.txt.
- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
== 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
kernel expects sequence-level batch).
- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
(workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
max_seqlen_kv; on real run call get_runtime_max_seqlen then
fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
get_runtime_max_seqlen, workspace size, and small-seq bwd.
- Reuse the softmax LSE auxiliary buffer for the attention weights in the
small-seq path (forward writes it, backward reads it).
- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
buffer matches C++ attention-weights convention.
- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
(parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
C++.
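The forward dispatch gating and the batch-count choice above can be sketched in Python (function and argument names are hypothetical illustrations, not the actual TE API):

```python
# Minimal sketch of the small-seq gating described in the PR:
# THD layout, no bias, s_q == 1, 2 <= s_kv <= 16.
def use_smallseq_path(layout, bias_type, max_seqlen_q, max_seqlen_kv):
    return (layout == "THD" and bias_type == "NO_BIAS"
            and max_seqlen_q == 1 and 2 <= max_seqlen_kv <= 16)

# With s_q == 1 every sequence contributes exactly one Q token, so the
# sequence-level batch the varlen kernel expects (b_varlen) equals
# max_tokens_q rather than the segment count b.
b, max_seqlen_q = 30720, 1
max_tokens_q = b * max_seqlen_q
b_varlen = max_tokens_q

assert use_smallseq_path("THD", "NO_BIAS", 1, 8)
assert not use_smallseq_path("THD", "NO_BIAS", 1, 32)  # s_kv out of range
assert not use_smallseq_path("BSHD", "NO_BIAS", 1, 8)  # not THD
assert b_varlen == b
```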
Let's make this PR work for the JAX extension first; later we can support PyTorch. One key difference between the JAX and PyTorch fused-attn dispatch is that PyTorch can calculate, request, and allocate softmax_aux and workspace at runtime with actual cu_seqlen_q/kv data. However, in the JAX extension, the softmax_aux and workspace calculation is done in TransformerEngine/transformer_engine/jax/cpp_extensions/attention.py, lines 364 to 375 in b685686. General guideline:
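The two-phase pattern the comment refers to (a shape-only query before cu_seqlens data exists, then the real run) can be modeled roughly as follows; all names here are illustrative, not the actual TE API:

```python
# Hypothetical model of the two-phase dispatch in the small-seq forward path:
# on the shape query the cu_seqlens pointers are null, so the host-side
# max_seqlen_kv determines the reported attention-weights shape; on the real
# run the runtime max seqlen is derived from the cumulative-length array.
def smallseq_fwd(shape_query, cu_seqlens_kv, host_max_seqlen_kv, max_tokens_q, h_q):
    if shape_query:
        # Shape query: report {max_tokens_q, h_q, 1, max_seqlen_kv}.
        return ("query", (max_tokens_q, h_q, 1, host_max_seqlen_kv))
    # Real run: per-sequence lengths are differences of cumulative offsets.
    runtime_s_kv = max(b - a for a, b in zip(cu_seqlens_kv, cu_seqlens_kv[1:]))
    return ("run", (max_tokens_q, h_q, 1, runtime_s_kv))

assert smallseq_fwd(True, None, 16, 4, 2) == ("query", (4, 2, 1, 16))
assert smallseq_fwd(False, [0, 3, 8, 12], 16, 4, 2) == ("run", (4, 2, 1, 5))
```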
NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");

float sqr_dk_scale = attn_scale;
hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
Probably no need for this cast. cudaStream_t will be hipified correctly to hipStream_t
…port to small-seq kernels
- tests/jax: CK small-seq tests use a fixture to set/restore NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing cases (2048-2-4, 2-4096-8192); when the env var is set, num_segments_per_seq = max_seqlen_q for THD, else 2.
- JAX attention.py: THD softmax shape/dtype uses the small-seq path only when env=1, else the original layout.
- JAX attention.cpp: add env guard.
- fused_attn_smallseq: use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
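The env-guarded layout selection can be sketched as below. The NVTE_FUSED_ATTN_CK_SMALLSEQ switch comes from this PR; the helper name and the float32 dtype of the original LSE layout are assumptions for illustration:

```python
import os

# Sketch of choosing the softmax aux-buffer layout in attention.py:
# small-seq attention-weights layout only when the env switch is set,
# otherwise the original softmax-LSE layout (assumed float32 here).
def softmax_layout(batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen, q_dtype):
    enabled = os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1"
    if enabled and q_max_seqlen == 1 and kv_max_seqlen <= 16:
        # Attention-weights convention: (*batch, h, s_q, s_kv), Q dtype.
        return (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen), q_dtype
    # Original LSE convention: (*batch, h, s_q, 1), float32 (assumed).
    return (*batch_shape, attn_heads, q_max_seqlen, 1), "float32"

os.environ["NVTE_FUSED_ATTN_CK_SMALLSEQ"] = "1"
assert softmax_layout((30720,), 16, 1, 8, "bfloat16") == ((30720, 16, 1, 8), "bfloat16")
os.environ["NVTE_FUSED_ATTN_CK_SMALLSEQ"] = "0"
assert softmax_layout((30720,), 16, 1, 8, "bfloat16") == ((30720, 16, 1, 1), "float32")
```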