Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ run_test_config() {
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run_default_fa 1 test_helper.py
run_default_fa 1 test_layer.py # it effectively always uses unfused attention
Expand Down
137 changes: 117 additions & 20 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import os
import random

import jax
Expand Down Expand Up @@ -329,7 +330,11 @@ class FusedAttnRunner:
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if self.qkv_layout.is_thd():
if 90400 <= get_cudnn_version() < 90500:
if (
90400 <= get_cudnn_version() < 90500
or ( is_hip_extension() and
os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1")
):
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
Expand Down Expand Up @@ -539,27 +544,60 @@ def generate_random_segment_ids(
return segment_ids, segment_pos, segment_pad

if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
if self.max_seqlen_q == 1 and is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
self.num_segments_per_seq = 1
# Q: deterministic — one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size]
self.segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
self.segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
self.pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
self.seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32)
self.offsets_q = jnp.concatenate(
[
jnp.arange(self.batch_size, dtype=jnp.int32)[:, None],
jnp.full((self.batch_size, 1), -1, dtype=jnp.int32),
],
axis=1,
)

# KV: one segment per batch (num_segments_per_seq=1) to match smallseq kernel
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = (
generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq, # 1 for s_q=1 path
seed=2024,
min_segment_len=min_segment_len,
)
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
self.segment_ids_kv
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
self.num_segments_per_seq = self.max_seqlen_q
else:
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
Expand Down Expand Up @@ -1214,3 +1252,62 @@ def test_jax_new_rng():
)
runner = FusedAttnRunner(**kwargs)
runner.test_forward()


# ROCm CK small-seq varlen tests.
@pytest.fixture
def ck_smallseq_env(monkeypatch):
    """Enable the CK small-seq path for a test and restore the env afterwards.

    Skips the test unless it runs on AMD (ROCm) hardware and unless XLA GPU
    graphs are disabled via XLA_FLAGS (the CI invocation in ci/jax.sh sets
    ``--xla_gpu_graph_level=0`` for these tests — presumably the small-seq
    kernels are incompatible with graph capture; confirm with the backend).

    NOTE(review): the hardware check lives in the fixture body because
    ``@pytest.mark.skipif`` cannot be applied to a fixture — pytest ignores
    (and in recent versions rejects) marks on fixtures.
    """
    # Hardware guard: the CK unfused smallseq backend exists only on ROCm.
    if not is_hip_extension():
        pytest.skip("CK unfused smallseq backend only available on AMD hardware")
    # Require the same XLA flag the CI uses instead of failing obscurely.
    if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""):
        pytest.skip("Run with XLA_FLAGS='--xla_gpu_graph_level=0' pytest ...")
    # monkeypatch restores the variable automatically after the test.
    monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1")
    yield

@pytest.mark.skipif(
    not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware"
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"])
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v",
    [
        pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"),
        pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"),
        pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"),
        pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"),
        pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"),
        pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"),
        pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
        pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
    ],
)
def test_ck_unfused_smallseq_backend(
    ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype
):
    """
    Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout.
    Uses THD_THD_THD (Q,K,V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and
    restores it after the test. The skipif mark above guards non-AMD hardware
    directly on the test (marks on fixtures are not honored by pytest).
    """
    runner = FusedAttnRunner(
        batch_size=b,
        max_seqlen_q=s_q,
        max_seqlen_kv=s_kv,
        num_heads_q=h_q,
        num_heads_kv=h_kv,
        head_dim_qk=d_qk,
        head_dim_v=d_v,
        attn_bias_type=AttnBiasType.NO_BIAS,
        attn_mask_type=AttnMaskType.PADDING_MASK,
        dropout_prob=0.0,
        use_old_rng=True,
        dtype=dtype,
        is_training=True,
        qkv_layout=QKVLayout.THD_THD_THD,
        bias_shape=None,
        window_size=None,
        seq_desc_format=SeqDescFormat.Seqlens,
    )
    runner._setup_inputs()
    # Backward pass covers the gradient path of the small-seq kernel; the
    # forward-only check was removed as dead code (it was commented out).
    runner.test_backward()
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ else()
fused_attn_rocm/fused_attn.cpp
fused_attn_rocm/fused_attn_aotriton.cpp
fused_attn_rocm/fused_attn_ck.cpp
fused_attn_rocm/fused_attn_smallseq.cpp
fused_attn_rocm/utils.cpp
gemm/rocm_gemm.cu
amd_detail/system.cpp)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
Expand Down Expand Up @@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd(
int how_v3_bf16_cvt,
hipStream_t stream);

// Returns the largest per-sequence length in a batch, computed at runtime on
// device from a cumulative-sequence-length array.
//   b                    - number of sequences in the batch
//   cu_seqlen_ptr        - device pointer to cumulative seqlens; callers in
//                          fused_attn_ck.cpp pass the Q or KV cu_seqlens
//   cu_seqlen_padded_ptr - callers currently pass nullptr; presumably the
//                          padded cumulative seqlens — confirm at definition
//   workspace            - scratch device memory (usage defined elsewhere —
//                          TODO confirm size requirements)
//   stream               - HIP stream the computation is enqueued on
uint64_t get_runtime_max_seqlen(uint64_t b,
const void* cu_seqlen_ptr,
const void* cu_seqlen_padded_ptr,
void* workspace,
hipStream_t stream);

}//namespace ck_fused_attn
#endif // CK_FUSED_ATTN_H

65 changes: 62 additions & 3 deletions transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
Expand All @@ -9,6 +9,7 @@
#include <numeric> // Required for std::accumulate
#ifdef USE_FUSED_ATTN_CK
#include <ck_fused_attn/ck_fused_attn.hpp>
#include "fused_attn_smallseq.h"
#endif // USE_FUSED_ATTN_CK
#include "../util/cuda_runtime.h"
#include "../util/system.h"
Expand Down Expand Up @@ -614,6 +615,34 @@ void fused_attn_ck_fwd_impl(
// denote the next available section of workspace from upstream
void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
void* max_seqlen_workspace = workspace;

size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));

if (std::getenv("NVTE_LOG_CK_CONFIG")) {
std::cout << std::endl << "attn_fwd(ck small-seq): ";
std::cout << "b: " << b << ", ";
std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
fused_attn_rocm::fused_attn_smallseq_fwd(
b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
is_training, scaling_factor, dropout_probability,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux,
devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
devPtrDropoutSeed, devPtrDropoutOffset,
dtype, workspace, workspace_size, stream);
return;
}
}

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
Expand Down Expand Up @@ -916,6 +945,35 @@ void fused_attn_ck_bwd_impl(
// denote the next available section of workspace from upstream
void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
void* max_seqlen_workspace = workspace;

size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));

if (std::getenv("NVTE_LOG_CK_CONFIG")) {
std::cout << std::endl << "attn_bwd(ck small-seq): ";
std::cout << "b: " << b << ", ";
std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {

fused_attn_rocm::fused_attn_smallseq_bwd(
b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
scaling_factor, dropout_probability,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux,
devPtrdQ, devPtrdK, devPtrdV,
devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
dtype, workspace, workspace_size, stream);
return;
}
}

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
Expand Down Expand Up @@ -1828,7 +1886,8 @@ void fused_attn_ck_fwd(
size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_q/d_qk;
size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_kv/d_qk;

bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;

if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Expand Down Expand Up @@ -1883,7 +1942,7 @@ void fused_attn_ck_fwd(
bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

fused_attn_ck_fwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h,
max_tokens_q, max_tokens_kv,
Expand Down
Loading
Loading