Skip to content

[Perf] Further tunings for SM100 FP8 CUTLASS kernel #19566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
// M in (128, inf)
// M in (256, inf)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment indicates that this configuration applies to M in the range (256, inf). Double check that this is the intended range, and that the upper bound is indeed unbounded.

static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_256, _128, _64>;
using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
// M in (128, 256]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment indicates that this configuration applies to M in the range (128, 256]. Consider adding a unit test to verify that this config is used when M is 129, 192, and 256.

static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
Expand All @@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_2, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
Expand Down Expand Up @@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;

uint32_t const m = a.size(0);
uint32_t const mp2 =
Expand All @@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// m in (128, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
// m in (256, inf)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment indicates that this configuration applies to M in the range (256, inf). Consider adding a unit test to verify that this config is used when M is 257, 512, and 1024.

return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
Expand Down