#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

- #include <fstream>
- #include <iostream>
- #include <sstream>
- #include <vector>
+ #include "group_gemm_runner.cuh"

- #include "../../cuda/cuda_common.h"
-
- // clang-format off
- #include "cutlass/cutlass.h"
-
- #include "cute/tensor.hpp"
- #include "cutlass/tensor_ref.h"
- #include "cutlass/epilogue/collective/default_epilogue.hpp"
- #include "cutlass/epilogue/thread/linear_combination.h"
- #include "cutlass/gemm/dispatch_policy.hpp"
- #include "cutlass/gemm/group_array_problem_shape.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
- #include "cutlass/epilogue/collective/collective_builder.hpp"
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
- // clang-format on
-
- #define CUTLASS_CHECK(status)                                                                    \
-   {                                                                                              \
-     cutlass::Status error = status;                                                              \
-     if (error != cutlass::Status::kSuccess) {                                                    \
-       std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
-                 << std::endl;                                                                    \
-       exit(EXIT_FAILURE);                                                                        \
-     }                                                                                            \
-   }
-
- using namespace cute;
- using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

- inline size_t aligned(size_t value, size_t alignment = 16) {
-   return (value + alignment - 1) / alignment * alignment;
- }
-
- template <typename ElementA, typename ElementB, typename ElementC,
-           typename LayoutA = cutlass::layout::RowMajor,
-           typename LayoutB = cutlass::layout::ColumnMajor,
-           typename LayoutC = cutlass::layout::RowMajor>
- struct CutlassFP8GroupGemmRunner {
-   static constexpr int AlignmentA =
-       128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements
-                                                     // (up to 16 bytes)
-
-   static constexpr int AlignmentB =
-       128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements
-                                                     // (up to 16 bytes)
-
-   static constexpr int AlignmentC =
-       128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements
-                                                     // (up to 16 bytes)
-
-   // Core kernel configurations
-   using ElementAccumulator = float;  // Element type for internal accumulation
-   using ArchTag =
-       cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
-   using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+ template <>
+ struct KernelTraits<cutlass::float_e4m3_t> {
+   using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
-   using StageCountType =
-       cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
-   using KernelSchedule =
-       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;  // Kernel to launch
-   using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  // Epilogue to launch
-
-   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
-       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
-       cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
-       ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
-       EpilogueSchedule>::CollectiveOp;
-
-   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-       ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
-       ElementAccumulator, TileShape, ClusterShape,
-       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
-           sizeof(typename CollectiveEpilogue::SharedStorage))>,
-       KernelSchedule>::CollectiveOp;
-
-   using GemmKernel =
-       cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
-
-   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-
-   using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
-   using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
-   using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
-   using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
-
-   void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
-                       ElementC** ptr_D,
-                       typename ProblemShape::UnderlyingProblemShape* problem_sizes,
-                       typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
-                       StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
-                       uint8_t* workspace, int64_t workspace_size, int num_groups, float alpha,
-                       float beta, cudaStream_t stream) {
-     typename Gemm::EpilogueOutputOp::Params epilogue_params{ElementAccumulator(alpha),
-                                                             ElementAccumulator(beta)};
-
-     cutlass::KernelHardwareInfo hw_info;
-     hw_info.device_id = 0;
-     hw_info.sm_count =
-         cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-     typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
-                                        {num_groups, problem_sizes, problem_sizes_host},
-                                        {ptr_A, stride_A, ptr_B, stride_B},
-                                        {epilogue_params, ptr_C, stride_C, ptr_D, stride_D},
-                                        hw_info};
-     Gemm gemm_op;
-     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-     CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
-     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
-     CUTLASS_CHECK(gemm_op.run());
-   }
};

- template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
-           typename StrideB, typename StrideC>
- __global__ void prepare_group_gemm_arguments(
-     const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
-     typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
-     StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
-     int64_t* indptr, int64_t n, int64_t k, int64_t num_experts) {
-   int expert_id = threadIdx.x;
-   if (expert_id >= num_experts) return;
-   int prev_rows = expert_id == 0 ? 0 : indptr[expert_id - 1];
-   ptr_A[expert_id] = x + prev_rows * k;
-   ptr_B[expert_id] = weight + expert_id * k * n;
-   ptr_D[expert_id] = out + prev_rows * n;
-   problem_sizes[expert_id] = {static_cast<int>(indptr[expert_id] - prev_rows),
-                               static_cast<int>(n), static_cast<int>(k)};
-   stride_A[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-   stride_B[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-   stride_D[expert_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
- }
-
- template <typename ElementA, typename ElementB, typename ElementC>
- void cutlass_fp8_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
-                             int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
-                             ElementC* out, cudaStream_t stream) {
-   using Runner = CutlassFP8GroupGemmRunner<ElementA, ElementB, ElementC>;
-   using StrideA = typename Runner::StrideA;
-   using StrideB = typename Runner::StrideB;
-   using StrideC = typename Runner::StrideC;
-
-   Runner runner;
-   std::ptrdiff_t offset = 0;
-   const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
-   offset += aligned(sizeof(ElementA*) * num_groups);
-   const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
-   offset += aligned(sizeof(ElementB*) * num_groups);
-   ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
-   offset += aligned(sizeof(ElementC*) * num_groups);
-   typename ProblemShape::UnderlyingProblemShape* problem_sizes =
-       reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
-   offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
-   StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
-   offset += aligned(sizeof(StrideA) * num_groups);
-   StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
-   offset += aligned(sizeof(StrideB) * num_groups);
-   StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
-   offset += aligned(sizeof(StrideC) * num_groups);
-   prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(
-       ptr_A, ptr_B, ptr_D, problem_sizes, stride_A, stride_B, stride_D, x, weight, out, indptr, n,
-       k, num_groups);
-   offset = aligned(offset, 256);
-   runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
-                         nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
-                         workspace_size - offset, num_groups, 1.0f, 0.0f, stream);
- }
+ template <>
+ struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {
+ };

namespace tvm {
namespace runtime {
@@ -218,10 +58,10 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
  int n = weight->shape[1];
  int k = weight->shape[2];
  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
-   cutlass_fp8_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
-                          static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
-                          workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
-                          stream);
+   cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                      static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                      workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
+                      stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
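The core of this commit is the move from one monolithic `CutlassFP8GroupGemmRunner` to per-dtype `KernelTraits` specializations consumed by a shared `group_gemm_runner.cuh`. Below is a minimal, self-contained sketch of that traits pattern; the stand-in element structs, the `schedule` member, and `run_group_gemm_sketch` are illustrative only and are not the actual contents of group_gemm_runner.cuh.

```cpp
#include <cstdio>

// Primary template, deliberately left undefined: instantiating it for an
// element type without a specialization is a compile-time error.
template <typename Element>
struct KernelTraits;

// Stand-ins for the CUTLASS FP8 element types named in the diff.
struct float_e4m3_t {};
struct float_e5m2_t {};

// e4m3 carries the full configuration (in the real code: KernelSchedule,
// TileShape, ClusterShape).
template <>
struct KernelTraits<float_e4m3_t> {
  static constexpr const char* schedule =
      "KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum";
};

// e5m2 inherits the e4m3 configuration wholesale, exactly as the diff does.
template <>
struct KernelTraits<float_e5m2_t> : KernelTraits<float_e4m3_t> {};

// A shared runner can then be written once, pulling per-dtype choices from
// the traits instead of duplicating the CollectiveBuilder plumbing per type.
template <typename ElementA>
void run_group_gemm_sketch() {
  std::printf("schedule: %s\n", KernelTraits<ElementA>::schedule);
}

int main() {
  run_group_gemm_sketch<float_e4m3_t>();
  run_group_gemm_sketch<float_e5m2_t>();  // same configuration via inheritance
}
```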