Skip to content

Commit

Permalink
Support for TMA Epilogue for Group Gemm and add pingpong ptr array & …
Browse files Browse the repository at this point in the history
…Group Gemm (NVIDIA#1795)
  • Loading branch information
Junkai-Wu authored Sep 11, 2024
1 parent 21d0534 commit dbdae51
Show file tree
Hide file tree
Showing 23 changed files with 2,359 additions and 347 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Different configs for pingpong/cooperative
struct CooperativeConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_256,_128,_64>;
using ClusterShape = Shape<_1,_2,_1>;
};

struct PingpongConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = Shape<_64,_128,_64>;
using ClusterShape = Shape<_1,_1,_1>;
};

template <typename ScheduleConfig>
struct GemmGivenSchedule {
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;

using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;


// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
Expand Down Expand Up @@ -261,14 +287,14 @@ bool initialize_block(
int bits_input = cutlass::sizeof_bits<Element>::value;

if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(0);
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(-2);
} else {
scope_max = 8;
scope_min = -8;
scope_max = static_cast<Element>(8);
scope_min = static_cast<Element>(-8);
}

cutlass::reference::device::BlockFillRandomUniform(
Expand Down Expand Up @@ -351,15 +377,16 @@ void initialize(const Options &options) {
}

/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
template <typename GemmT>
typename GemmT::Arguments args_from_options(const Options &options)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

typename Gemm::Arguments arguments{
typename GemmT::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kArray,
{{options.m, options.n, options.k, options.l}},
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
Expand Down Expand Up @@ -405,20 +432,20 @@ bool verify(const Options &options) {
}

/// Execute a given example GEMM computation
template <typename Gemm>
template <typename GemmT>
int run(Options &options)
{
allocate(options);
initialize(options);

// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
GemmT gemm;

// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
auto arguments = args_from_options<GemmT>(options);

// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = GemmT::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Expand Down Expand Up @@ -510,7 +537,10 @@ int main(int argc, char const **args) {
//

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
std::cout << "\n*** Cooperative schedule ***" << std::endl;
run<Gemm>(options);
std::cout << "\n*** Pingpong schedule ***" << std::endl;
run<GemmPingpong>(options);
#endif

return 0;
Expand Down
105 changes: 77 additions & 28 deletions examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
// Different configs for pingpong/cooperative
struct CooperativeConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_256,_128,_128>;
using ClusterShape = Shape<_2,_2,_1>;
};

struct PingpongConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = Shape<_128,_128,_128>;
using ClusterShape = Shape<_2,_1,_1>;
};

template <typename ScheduleConfig>
struct GemmGivenSchedule {
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
Expand All @@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
KernelSchedule
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;

using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;

// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
Expand Down Expand Up @@ -271,10 +297,10 @@ struct Options {
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = ((rand() % 512) + 1);
m = alignment * ((rand() % 64) + 1);
}
if (n < 1) {
n = ((rand() % 512) + 1);
n = alignment * ((rand() % 64) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
Expand Down Expand Up @@ -521,41 +547,58 @@ void initialize(const Options &options) {
}

/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
template <typename GemmT>
typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

typename Gemm::EpilogueOutputOp::Params params;
typename GemmT::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;

if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
params = typename Gemm::EpilogueOutputOp::Params(
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}

typename Gemm::Arguments arguments;
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
arguments = typename GemmT::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
arguments = typename Gemm::Arguments {
arguments = typename GemmT::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
Expand Down Expand Up @@ -605,20 +648,20 @@ bool verify(const Options &options) {
}

/// Execute a given example GEMM computation
template <typename Gemm>
template <typename GemmT>
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);

// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
GemmT gemm;

// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options, host_problem_shapes_available);
auto arguments = args_from_options<GemmT>(options, host_problem_shapes_available);

// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = GemmT::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Expand Down Expand Up @@ -713,8 +756,14 @@ int main(int argc, char const **args) {
//

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
std::cout << "\n*** Cooperative schedule ***" << std::endl;
run<Gemm>(options);
std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
run<Gemm>(options, false /*host_problem_shapes_available*/);
std::cout << "\n*** Pingpong schedule ***" << std::endl;
run<GemmPingpong>(options);
std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
#endif

return 0;
Expand Down
4 changes: 2 additions & 2 deletions examples/57_hopper_grouped_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes

set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes

set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes

set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
Expand Down
Loading

0 comments on commit dbdae51

Please sign in to comment.