-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Perf] Further tunings for SM100 FP8 CUTLASS kernel #19566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
simon-mo
merged 1 commit into
vllm-project:main
from
neuralmagic:imarkov/fp8_cutlass_configs
Jun 15, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller; | |
template <typename InType, typename OutType, | ||
template <typename, typename, typename> typename Epilogue> | ||
struct sm100_fp8_config_default { | ||
// M in (128, inf) | ||
// M in (256, inf) | ||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||
using TileShape = Shape<_256, _128, _64>; | ||
using TileShape = Shape<_256, _128, _128>; | ||
using ClusterShape = Shape<_2, _2, _1>; | ||
using Cutlass3xGemm = | ||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||
KernelSchedule, EpilogueSchedule>; | ||
}; | ||
|
||
template <typename InType, typename OutType, | ||
template <typename, typename, typename> typename Epilogue> | ||
struct sm100_fp8_config_M256 { | ||
// M in (128, 256] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||
using TileShape = Shape<_128, _128, _128>; | ||
using ClusterShape = Shape<_2, _2, _1>; | ||
using Cutlass3xGemm = | ||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||
|
@@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 { | |
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||
using TileShape = Shape<_128, _128, _64>; | ||
using ClusterShape = Shape<_2, _2, _1>; | ||
using TileShape = Shape<_128, _128, _256>; | ||
using ClusterShape = Shape<_2, _4, _1>; | ||
using Cutlass3xGemm = | ||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||
KernelSchedule, EpilogueSchedule>; | ||
|
@@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, | |
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; | ||
using Cutlass3xGemmM128 = | ||
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm; | ||
using Cutlass3xGemmM256 = | ||
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm; | ||
|
||
uint32_t const m = a.size(0); | ||
uint32_t const mp2 = | ||
|
@@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, | |
// m in (64, 128] | ||
return cutlass_gemm_caller<Cutlass3xGemmM128>( | ||
out, a, b, std::forward<EpilogueArgs>(args)...); | ||
} else if (mp2 <= 256) { | ||
// m in (128, 256] | ||
return cutlass_gemm_caller<Cutlass3xGemmM256>( | ||
out, a, b, std::forward<EpilogueArgs>(args)...); | ||
} else { | ||
// m in (128, inf) | ||
// m in (256, inf) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
return cutlass_gemm_caller<Cutlass3xGemmDefault>( | ||
out, a, b, std::forward<EpilogueArgs>(args)...); | ||
} | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment indicates that this configuration applies to M in the range (256, inf). Double check that this is the intended range, and that the upper bound is indeed unbounded.