
Commit 4e1834f

ngimel authored and pytorchmergebot committed
use cooperative schedule in scaled_mm for fast_accum=false (#144809)
This improves perf for large matrices by more than 2x; a more detailed benchmark is coming.

On master: ![image](https://github.com/user-attachments/assets/fc6a0987-5b82-475d-a2ff-b46641bb17dc)

On this branch: ![image](https://github.com/user-attachments/assets/7f55152b-1110-45e4-b2ea-6f274d543869)

A plot similar to the one in pytorch/ao#1325 (comment).

Benchmarking code:

```python
import itertools

import torch
from triton.testing import do_bench


def fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=False):
    # Row-wise scaled fp8 matmul: per-row scale for a, per-column scale for b.t().
    return torch._scaled_mm(a, b.t(), scale_a.view(-1, 1), scale_b.view(1, -1),
                            use_fast_accum=use_fast_accum, out_dtype=torch.bfloat16)


def fn_aten(a, b, scale, use_fast_accum=False):
    # Tensor-wise scaled fp8 matmul: a single scalar scale for both operands.
    return torch._scaled_mm(a, b.t(), scale, scale,
                            use_fast_accum=use_fast_accum, out_dtype=torch.bfloat16)


# Sweep M, N, K over powers of two from 2**9 to 2**14.
for i, j, k in itertools.product(range(9, 15), range(9, 15), range(9, 15)):
    m = 2**i
    n = 2**j
    k = 2**k  # rebind k from exponent to actual dimension size
    a = torch.randn(m, k, device="cuda").to(dtype=torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(dtype=torch.float8_e4m3fn)
    scale_a = torch.randint(1, 11, (a.shape[0],), device="cuda", dtype=torch.float32)
    scale_b = torch.randint(1, 11, (b.shape[0],), device="cuda", dtype=torch.float32)
    scale_0 = torch.randn((), device="cuda", dtype=torch.float32)
    ms_rowwise_fast = do_bench(lambda: fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=True), warmup=25, rep=50)
    ms_rowwise_slow = do_bench(lambda: fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=False), warmup=25, rep=50)
    ms_tensor_fast = do_bench(lambda: fn_aten(a, b, scale_0, use_fast_accum=True), warmup=25, rep=50)
    ms_tensor_slow = do_bench(lambda: fn_aten(a, b, scale_0, use_fast_accum=False), warmup=25, rep=50)
    print(f"m={m}, n={n}, k={k}, fast={ms_rowwise_fast}, slow={ms_rowwise_slow}, "
          f"ratio_tw={ms_tensor_slow / ms_tensor_fast}, ratio_rw={ms_rowwise_slow / ms_rowwise_fast}")
```

Higher N/K values still carry about a 40% penalty; some additional heuristics tweaks may be useful.

Pull Request resolved: #144809
Approved by: https://github.com/drisspg
1 parent 0f051ea · commit 4e1834f


aten/src/ATen/native/cuda/RowwiseScaledMM.cu

Lines changed: 17 additions & 11 deletions
```diff
@@ -105,27 +105,34 @@ using Cast = cutlass::epilogue::fusion::Sm90Compute<
     DtypeEpilogue,
     cutlass::FloatRoundStyle::round_to_nearest>;
 
-template <bool PingPong, bool FastAccum>
+template <bool LargeTile, bool FastAccum>
 struct Schedule;
 
 template <>
-struct Schedule</*PingPong=*/false, /*FastAccum=*/false> {
+struct Schedule</*LargeTile=*/false, /*FastAccum=*/false> {
   using type = cutlass::gemm::KernelTmaWarpSpecialized;
+  using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
 };
 
 template <>
-struct Schedule</*PingPong=*/true, /*FastAccum=*/false> {
-  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+struct Schedule</*LargeTile=*/true, /*FastAccum=*/false> {
+  // For a 128x128x128 tile with fastAccum = false, using
+  // pingpong schedule will lead to spilling, and WarpSpecialized w/o pingpong
+  // is slow
+  using type = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
+  using epilogue_type = cutlass::epilogue::TmaWarpSpecializedCooperative;
 };
 
 template <>
-struct Schedule</*PingPong=*/false, /*FastAccum=*/true> {
+struct Schedule</*LargeTile=*/false, /*FastAccum=*/true> {
   using type = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
+  using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
 };
 
 template <>
-struct Schedule</*PingPong=*/true, /*FastAccum=*/true> {
+struct Schedule</*LargeTile=*/true, /*FastAccum=*/true> {
   using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
 };
 
 int ceildiv(int a, int b) {
@@ -140,7 +147,6 @@ int round_up_to_nearest_multiple(int a, int b) {
 template <
     typename TileShape,
     typename ClusterShape,
-    typename PingPong,
     typename Transposed,
     typename FastAccum,
     typename DtypeA,
@@ -226,6 +232,8 @@ void f8f8bf16_rowwise_impl(
           Bias,
           AccumScale>>;
 
+  constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag,
@@ -241,7 +249,7 @@ void f8f8bf16_rowwise_impl(
           DtypeOutput,
           LayoutOutput,
           AlignmentOutput,
-          cutlass::epilogue::TmaWarpSpecialized,
+          typename Schedule<large_tile, FastAccum::value>::epilogue_type,
           EpilogueEVT>::CollectiveOp;
 
   using CollectiveMainloop =
@@ -259,7 +267,7 @@ void f8f8bf16_rowwise_impl(
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
-          typename Schedule<PingPong::value, FastAccum::value>::type>::
+          typename Schedule<large_tile, FastAccum::value>::type>::
          CollectiveOp;
 
   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
@@ -370,13 +378,11 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
     return f8f8bf16_rowwise_impl<
         /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
         ClusterShape,
-        /*PingPong=*/std::false_type,
         Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
   } else {
     return f8f8bf16_rowwise_impl<
         /*TileShape=*/cute::Shape<cute::_128, cute::_128, cute::_128>,
         ClusterShape,
-        /*PingPong=*/std::true_type,
         Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
   }
 }
```
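For readers less familiar with the CUTLASS builder plumbing, the sketch below re-creates, in a self-contained form, the compile-time dispatch this diff performs. The tag structs, `TileShape`, and `PickSchedule` are stand-ins invented for illustration; the real code uses the `cutlass::gemm::*` / `cutlass::epilogue::*` schedule types and `cute::Shape`, and selects the schedule inside `f8f8bf16_rowwise_impl`.

```cpp
#include <type_traits>

// Stand-in tags for the CUTLASS kernel/epilogue schedules named in the diff.
struct KernelTmaWarpSpecialized {};
struct KernelTmaWarpSpecializedCooperative {};
struct KernelTmaWarpSpecializedFP8FastAccum {};
struct KernelTmaWarpSpecializedPingpongFP8FastAccum {};
struct EpilogueTmaWarpSpecialized {};
struct EpilogueTmaWarpSpecializedCooperative {};

// Stand-in for cute::Shape<cute::_M, cute::_N, cute::_K>.
template <int M, int N, int K>
struct TileShape {};

// The Schedule trait is now keyed on (LargeTile, FastAccum) instead of
// (PingPong, FastAccum).
template <bool LargeTile, bool FastAccum>
struct Schedule;

template <>
struct Schedule</*LargeTile=*/false, /*FastAccum=*/false> {
  using type = KernelTmaWarpSpecialized;
  using epilogue_type = EpilogueTmaWarpSpecialized;
};

template <>
struct Schedule</*LargeTile=*/true, /*FastAccum=*/false> {
  // 128x128x128 tiles with slow accumulation: pingpong spills registers and
  // plain warp-specialized is slow, so the cooperative schedule is chosen.
  using type = KernelTmaWarpSpecializedCooperative;
  using epilogue_type = EpilogueTmaWarpSpecializedCooperative;
};

template <>
struct Schedule</*LargeTile=*/false, /*FastAccum=*/true> {
  using type = KernelTmaWarpSpecializedFP8FastAccum;
  using epilogue_type = EpilogueTmaWarpSpecialized;
};

template <>
struct Schedule</*LargeTile=*/true, /*FastAccum=*/true> {
  using type = KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using epilogue_type = EpilogueTmaWarpSpecialized;
};

// Mirrors the `constexpr bool large_tile = std::is_same_v<...>` line in the
// diff: the kernel derives LargeTile from the tile shape it was given.
template <typename Tile, bool FastAccum>
struct PickSchedule {
  static constexpr bool large_tile =
      std::is_same_v<Tile, TileShape<128, 128, 128>>;
  using mainloop_schedule = typename Schedule<large_tile, FastAccum>::type;
  using epilogue_schedule = typename Schedule<large_tile, FastAccum>::epilogue_type;
};

// Large tile + fast_accum=false now resolves to the cooperative schedule ...
static_assert(std::is_same_v<
    PickSchedule<TileShape<128, 128, 128>, /*FastAccum=*/false>::mainloop_schedule,
    KernelTmaWarpSpecializedCooperative>);
// ... while the small tile keeps the plain warp-specialized schedule.
static_assert(std::is_same_v<
    PickSchedule<TileShape<64, 128, 128>, /*FastAccum=*/false>::mainloop_schedule,
    KernelTmaWarpSpecialized>);

int main() { return 0; }
```

Keying `Schedule` on the tile size rather than on a separately passed `PingPong` flag is what lets `dispatch_fp8_rowwise_kernel_on_tile_size` drop the extra template argument: the kernel now infers the schedule directly from the tile shape it is instantiated with.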
