Commit 2d7dba5

An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent 6d0df0e commit 2d7dba5

File tree: 4 files changed (+86, -50 lines)

csrc/rocm/attention.cu

Lines changed: 73 additions & 37 deletions
@@ -282,7 +282,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -796,7 +797,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -1238,6 +1240,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
 
   // final write to tmp_out after vout accumulation
   if (warpid == 0) {
+    const float out_scale =
+        (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
     _B16x4 vout[QHLOOP][VHELOOP];
     // iterate across heads
     for (int qh = 0; qh < QHLOOP; qh++) {
@@ -1286,7 +1290,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                                   // max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1464,8 +1468,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
-
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
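
Note: the reduce kernel above folds the optional fp8 output scale into its final write. The softmax-normalized accumulator is multiplied by 1.0 / *fp8_out_scale_ptr when a scale is supplied, and by 1.0 otherwise, before being converted to the output type OUTT. A minimal Python sketch of that convention follows; the scale_reduce_output helper name and the float8_e4m3fnuz output dtype are illustrative assumptions, not part of this commit.

from typing import Optional

import torch


def scale_reduce_output(acc: torch.Tensor,
                        global_exp_sum: torch.Tensor,
                        fp8_out_scale: Optional[torch.Tensor]) -> torch.Tensor:
    # Mirror of the reduce-kernel math above: normalize by the global softmax
    # sum, then fold in the reciprocal of the optional fp8 output scale.
    inv_global_exp_sum = 1.0 / (global_exp_sum + 1e-6)
    out_scale = 1.0 / fp8_out_scale.item() if fp8_out_scale is not None else 1.0
    acc = acc * inv_global_exp_sum
    acc = acc * out_scale
    # With a scale present the result is stored as fp8; float8_e4m3fnuz is an
    # assumed fp8 dtype for illustration, not something this commit pins down.
    return acc.to(torch.float8_e4m3fnuz) if fp8_out_scale is not None else acc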
@@ -1505,7 +1511,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 
@@ -1532,7 +1539,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
    OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 
@@ -1547,7 +1555,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1563,7 +1571,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
 
 #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
   paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
@@ -1574,14 +1582,15 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
 
 #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
   paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
                                       PARTITION_SIZE, NPAR_LOOPS> \
       <<<reduce_grid, reduce_block, 0, stream>>>( \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
+          fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1593,7 +1602,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1625,6 +1634,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const float* fp8_out_scale_ptr =
+      fp8_out_scale
+          ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
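
Note: the launcher only ever reads a single float through fp8_out_scale_ptr, so the optional tensor is effectively a one-element float32 scalar on the same device as the outputs; passing no tensor keeps the old unscaled path. A hedged sketch of preparing such a tensor on the Python side; the dtype/shape expectations are inferred from the reinterpret_cast above rather than stated by the commit, and make_fp8_out_scale is an illustrative helper name.

from typing import Optional

import torch


def make_fp8_out_scale(scale: Optional[float],
                       device: torch.device) -> Optional[torch.Tensor]:
    # None keeps the unscaled path (the kernels fall back to out_scale = 1.0).
    if scale is None:
        return None
    # A single float32 element, matching the const float* the kernels read.
    return torch.tensor([scale], dtype=torch.float32, device=device)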
@@ -1735,33 +1749,54 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \
-                             ALIBI_ENABLED) \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
-                                  PSIZE, ALIBI_ENABLED>( \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
-      num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
-      max_context_len, alibi_slopes, k_scale, v_scale);
-
-#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
-                                   PSIZE) \
-  if (alibi_slopes) { \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
-  } else { \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                             PSIZE, ALIBI_ENABLED) \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                                  PSIZE, ALIBI_ENABLED>( \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
+
+#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                   OUTT, PSIZE) \
+  if (alibi_slopes) { \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
+                         true); \
+  } else { \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
+                         false); \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
-  switch (block_size) { \
-    case 16: \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
-      break; \
-    case 32: \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
-      break; \
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+    if (fp8_out_scale) { \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
+    } else { \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
+                                 256); \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+    if (fp8_out_scale) { \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                 uint8_t, 256); \
+    } else { \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
+                                 256); \
+    }
+#endif
+
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
+  switch (block_size) { \
+    case 16: \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
+      break; \
+    case 32: \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
+      break; \
+    default: \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break; \
   }
 
 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
@@ -1794,7 +1829,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
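
Note: CALL_CUSTOM_LAUNCHER_OUT selects the kernel output type OUTT from the presence of fp8_out_scale: uint8_t storage (an fp8 payload) when a scale is passed, the query dtype T otherwise, and a hard TORCH_CHECK failure on gfx90a, where the fp8 output path is not supported. A small Python sketch of the same decision from the caller's point of view; pick_output_dtype and the is_gfx90a flag are illustrative assumptions, not code from this commit.

from typing import Optional

import torch


def pick_output_dtype(query_dtype: torch.dtype,
                      fp8_out_scale: Optional[torch.Tensor],
                      is_gfx90a: bool) -> torch.dtype:
    # Mirrors CALL_CUSTOM_LAUNCHER_OUT: fp8 output only when a scale is given.
    if fp8_out_scale is None:
        return query_dtype
    if is_gfx90a:
        raise ValueError("fp8 out scale unsupported for gfx90a")
    # The kernel writes raw fp8 bytes into uint8_t storage; torch.uint8 is one
    # way a caller could view that buffer.
    return torch.uint8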

csrc/rocm/ops.h

Lines changed: 9 additions & 11 deletions
@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
 
-void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
-                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                     torch::Tensor& query, torch::Tensor& key_cache,
-                     torch::Tensor& value_cache, int64_t num_kv_heads,
-                     double scale, torch::Tensor& block_tables,
-                     torch::Tensor& context_lens,
-                     const std::optional<torch::Tensor>& query_start_loc,
-                     int64_t block_size, int64_t max_context_len,
-                     const std::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-                     torch::Tensor& v_scale);
+void paged_attention(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
+    int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale);

csrc/rocm/torch_bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       " int max_context_len,"
      " Tensor? alibi_slopes,"
       " str kv_cache_dtype,"
-      " Tensor k_scale, Tensor v_scale) -> ()");
+      " Tensor k_scale, Tensor v_scale,"
+      " Tensor? fp8_out_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
 
vllm/_custom_ops.py

Lines changed: 2 additions & 1 deletion
@@ -117,13 +117,14 @@ def paged_attention_rocm(
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
+    fp8_out_scale: Optional[torch.Tensor] = None,
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       query_start_loc, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale)
+                                      v_scale, fp8_out_scale)
 
 
 def mla_decode_kvcache_cpu(
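
Note: for callers, the only API change is the trailing optional fp8_out_scale argument, which defaults to None and preserves the previous behavior. A hedged end-to-end sketch of using it follows; the keyword names forwarded through attn_args, the output buffer shape, and the float8_e4m3fnuz view dtype are assumptions for illustration, not values fixed by this commit.

from typing import Optional

import torch

from vllm import _custom_ops as ops


def paged_attention_with_fp8_output(query: torch.Tensor,
                                    fp8_out_scale: Optional[torch.Tensor],
                                    attn_args: dict) -> torch.Tensor:
    """Illustrative wrapper; attn_args carries the unchanged keyword arguments
    (workspaces, caches, block tables, scales, ...) exactly as callers passed
    them before this commit."""
    # When a scale is given the kernel emits fp8 bytes, so the output buffer
    # uses an fp8 storage dtype (float8_e4m3fnuz is an assumption here);
    # otherwise the output keeps the query dtype, as before.
    out_dtype = (torch.float8_e4m3fnuz
                 if fp8_out_scale is not None else query.dtype)
    out = torch.empty_like(query, dtype=out_dtype)
    ops.paged_attention_rocm(out=out, query=query,
                             fp8_out_scale=fp8_out_scale, **attn_args)
    return out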
