@@ -282,7 +282,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -796,7 +797,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
    OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -1238,6 +1240,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
 
   // final write to tmp_out after vout accumulation
   if (warpid == 0) {
+    const float out_scale =
+        (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
     _B16x4 vout[QHLOOP][VHELOOP];
     // iterate across heads
     for (int qh = 0; qh < QHLOOP; qh++) {
@@ -1286,7 +1290,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1464,8 +1468,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
-
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
@@ -1505,7 +1511,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 
@@ -1532,7 +1539,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale,
+    const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 
@@ -1547,7 +1555,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1563,7 +1571,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
           block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
           max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
           kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
-          max_ctx_blocks, k_scale_ptr, v_scale_ptr);
+          max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
 
 #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO)                               \
   paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE,   \
@@ -1574,14 +1582,15 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
           block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
           max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
           kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
-          max_ctx_blocks, k_scale_ptr, v_scale_ptr);
+          max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
 
 #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                                    \
   paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE,           \
                                       PARTITION_SIZE, NPAR_LOOPS>              \
       <<<reduce_grid, reduce_block, 0, stream>>>(                              \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                  \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions,           \
+          fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1593,7 +1602,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1625,6 +1634,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const float* fp8_out_scale_ptr =
+      fp8_out_scale
+          ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
@@ -1735,33 +1749,54 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE,     \
-                             ALIBI_ENABLED)                                    \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,    \
-                                  PSIZE, ALIBI_ENABLED>(                       \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
-      num_kv_heads, scale, block_tables, context_lens, query_start_loc,        \
-      max_context_len, alibi_slopes, k_scale, v_scale);
-
-#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
-                                   PSIZE)                                      \
-  if (alibi_slopes) {                                                          \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true);  \
-  } else {                                                                     \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,      \
+                             PSIZE, ALIBI_ENABLED)                             \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                                  PSIZE, ALIBI_ENABLED>(                       \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc,        \
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
+
+#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
+                                   OUTT, PSIZE)                                \
+  if (alibi_slopes) {                                                          \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
+                         true);                                                \
+  } else {                                                                     \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
+                         false);                                               \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                  \
-  switch (block_size) {                                                        \
-    case 16:                                                                   \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256);        \
-      break;                                                                   \
-    case 32:                                                                   \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256);        \
-      break;                                                                   \
-    default:                                                                   \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
-      break;                                                                   \
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)      \
+    if (fp8_out_scale) {                                                       \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");              \
+    } else {                                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,     \
+                                 256);                                         \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)      \
+    if (fp8_out_scale) {                                                       \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,        \
+                                 uint8_t, 256);                                \
+    } else {                                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,     \
+                                 256);                                         \
+    }
+#endif
+
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                  \
+  switch (block_size) {                                                        \
+    case 16:                                                                   \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);               \
+      break;                                                                   \
+    case 32:                                                                   \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);               \
+      break;                                                                   \
+    default:                                                                   \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
+      break;                                                                   \
   }
 
 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE)                        \
@@ -1794,7 +1829,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
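
Taken together, these hunks thread an optional fp8_out_scale tensor from the paged_attention entry point through paged_attention_custom_launcher into the MFMA and reduction kernels: when the scale is supplied, the launcher macros select uint8_t (fp8) as the output type, except on gfx90a where the fp8 output path is rejected, and each kernel multiplies its final accumulator by the reciprocal of the scale before the write-back. The sketch below is not part of the patch; it only restates that scaling rule in plain host C++. The helpers resolve_out_scale and quantize_to_fp8_e4m3, the e4m3 clamp bound, and the byte encoding are illustrative assumptions, not the HIP conversions the kernels actually use.

// Standalone sketch of the optional fp8 output-scale rule used above.
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors the in-kernel selection: a null pointer means no fp8 output
// quantization was requested, so the multiplier degenerates to 1.0f.
inline float resolve_out_scale(const float* fp8_out_scale_ptr) {
  return (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
}

// Illustrative e4m3-style saturation; only the clamping behaviour matters
// here, the byte encoding is a placeholder for the real HIP fp8 conversion.
inline uint8_t quantize_to_fp8_e4m3(float x) {
  const float kMax = 448.0f;  // largest finite e4m3 value
  const float clamped = std::fmax(-kMax, std::fmin(kMax, x));
  return static_cast<uint8_t>(std::lround((clamped / kMax) * 127.0f) + 128);
}

int main() {
  float acc = 3.25f;              // softmax-weighted value accumulator
  float fp8_out_scale = 0.0625f;  // tensor-wide output scale (assumed input)

  // Without a scale the accumulator is written as-is (fp16/bf16 path).
  std::printf("unscaled out: %f\n", acc * resolve_out_scale(nullptr));

  // With a scale, acc is divided by it so that the stored fp8 value x
  // satisfies x * fp8_out_scale ~= acc after dequantization.
  const float scaled = acc * resolve_out_scale(&fp8_out_scale);
  std::printf("scaled out:   %f -> byte 0x%02x\n", scaled,
              static_cast<unsigned>(quantize_to_fp8_e4m3(scaled)));
  return 0;
}

Computing the reciprocal once and multiplying, rather than dividing per output element, matches the choice the patch makes with 1.0f / (*fp8_out_scale_ptr) in both the mfma4 and reduce kernels.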