
Commit 459f98f

Add fp16 grouped gemm support for sm90 (apache#54)
1 parent 0ac8ec9 commit 459f98f

5 files changed: +370 additions, -172 deletions

cmake/modules/contrib/CUTLASS.cmake

Lines changed: 7 additions & 1 deletion
@@ -60,6 +60,12 @@ if(USE_CUDA AND USE_CUTLASS)
   ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
   set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
   set(TVM_CUTLASS_RUNTIME_SRCS "")
+
+  # TODO: Should get rid of the postfix 'a' and test sm >= 90
+  if (CMAKE_CUDA_ARCHITECTURES MATCHES "90|90a")
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
+  endif()
+
   if (USE_CUDA_FP8)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_fp8_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
@@ -76,4 +82,4 @@ if(USE_CUDA AND USE_CUTLASS)
   list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")

   message(STATUS "Build with CUTLASS")
-endif()
+endif()

src/runtime/contrib/cutlass/fp16_group_gemm.cu

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "group_gemm_runner.cuh"
+
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+template <>
+struct KernelTraits<cutlass::half_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
+};
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
+                                 NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
+                     stream);
+}
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
+    .set_body_typed(
+        tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
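
Note: the new fp16 source only supplies a `KernelTraits<cutlass::half_t>` specialization and a thin packed-function wrapper; the actual grouped GEMM lives in the shared `group_gemm_runner.cuh` header, which is not shown in this diff. The sketch below is a hypothetical, self-contained illustration of that traits-specialization pattern, using plain stand-in types instead of the real CUTLASS schedule, tile, and cluster types:

```cpp
// Hypothetical, simplified illustration of the KernelTraits dispatch pattern.
// The real group_gemm_runner.cuh reads CUTLASS types (KernelSchedule, TileShape,
// ClusterShape); here strings and ints stand in for them.
#include <cstdio>

// Primary template is declared but not defined: instantiating the runner with an
// element type that has no KernelTraits specialization fails at compile time.
template <typename Element>
struct KernelTraits;

struct half_stub {};  // stand-in for cutlass::half_t

// Mirrors the specialization added in fp16_group_gemm.cu.
template <>
struct KernelTraits<half_stub> {
  static constexpr const char* kSchedule = "KernelPtrArrayTmaWarpSpecializedCooperative";
  static constexpr int kTileM = 128, kTileN = 256, kTileK = 64;
};

// The shared runner template reads its configuration from KernelTraits<ElementA>,
// so each element type only needs to provide a specialization.
template <typename ElementA>
void run_group_gemm_stub() {
  using Traits = KernelTraits<ElementA>;
  std::printf("schedule=%s tile=%dx%dx%d\n", Traits::kSchedule, Traits::kTileM, Traits::kTileN,
              Traits::kTileK);
}

int main() {
  run_group_gemm_stub<half_stub>();  // picks the half_t-style configuration
  return 0;
}
```

The payoff of this layout is that the fp16 and fp8 front-end files stay tiny: each declares only the schedule and tile configuration for its element type, and the shared runner selects it at compile time.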

src/runtime/contrib/cutlass/fp8_group_gemm.cu

Lines changed: 11 additions & 171 deletions
@@ -23,181 +23,21 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
-#include <fstream>
-#include <iostream>
-#include <sstream>
-#include <vector>
+#include "group_gemm_runner.cuh"
 
-#include "../../cuda/cuda_common.h"
-
-// clang-format off
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cutlass/tensor_ref.h"
-#include "cutlass/epilogue/collective/default_epilogue.hpp"
-#include "cutlass/epilogue/thread/linear_combination.h"
-#include "cutlass/gemm/dispatch_policy.hpp"
-#include "cutlass/gemm/group_array_problem_shape.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-// clang-format on
-
-#define CUTLASS_CHECK(status)                                                                    \
-  {                                                                                              \
-    cutlass::Status error = status;                                                              \
-    if (error != cutlass::Status::kSuccess) {                                                    \
-      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
-                << std::endl;                                                                    \
-      exit(EXIT_FAILURE);                                                                        \
-    }                                                                                            \
-  }
-
-using namespace cute;
-using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
 
 #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
 
-inline size_t aligned(size_t value, size_t alignment = 16) {
-  return (value + alignment - 1) / alignment * alignment;
-}
-
-template <typename ElementA, typename ElementB, typename ElementC,
-          typename LayoutA = cutlass::layout::RowMajor,
-          typename LayoutB = cutlass::layout::ColumnMajor,
-          typename LayoutC = cutlass::layout::RowMajor>
-struct CutlassFP8GroupGemmRunner {
-  static constexpr int AlignmentA =
-      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements
-                                                    // (up to 16 bytes)
-
-  static constexpr int AlignmentB =
-      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements
-                                                    // (up to 16 bytes)
-
-  static constexpr int AlignmentC =
-      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements
-                                                    // (up to 16 bytes)
-
-  // Core kernel configurations
-  using ElementAccumulator = float;  // Element type for internal accumulation
-  using ArchTag =
-      cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
-  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+template <>
+struct KernelTraits<cutlass::float_e4m3_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
   using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
   using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
-  using StageCountType =
-      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
-  using KernelSchedule =
-      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;  // Kernel to launch
-  using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  // Epilogue to launch
-
-  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
-      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
-      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
-      ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
-      EpilogueSchedule>::CollectiveOp;
-
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
-      ElementAccumulator, TileShape, ClusterShape,
-      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
-          sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      KernelSchedule>::CollectiveOp;
-
-  using GemmKernel =
-      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
-
-  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-
-  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
-  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
-  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
-  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
-
-  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
-                      ElementC** ptr_D,
-                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
-                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
-                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
-                      uint8_t* workspace, int64_t workspace_size, int num_groups, float alpha,
-                      float beta, cudaStream_t stream) {
-    typename Gemm::EpilogueOutputOp::Params epilogue_params{ElementAccumulator(alpha),
-                                                            ElementAccumulator(beta)};
-
-    cutlass::KernelHardwareInfo hw_info;
-    hw_info.device_id = 0;
-    hw_info.sm_count =
-        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
-                                       {num_groups, problem_sizes, problem_sizes_host},
-                                       {ptr_A, stride_A, ptr_B, stride_B},
-                                       {epilogue_params, ptr_C, stride_C, ptr_D, stride_D},
-                                       hw_info};
-    Gemm gemm_op;
-    CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
-    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
-    CUTLASS_CHECK(gemm_op.run());
-  }
 };
 
-template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
-          typename StrideB, typename StrideC>
-__global__ void prepare_group_gemm_arguments(
-    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
-    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
-    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
-    int64_t* indptr, int64_t n, int64_t k, int64_t num_experts) {
-  int expert_id = threadIdx.x;
-  if (expert_id >= num_experts) return;
-  int prev_rows = expert_id == 0 ? 0 : indptr[expert_id - 1];
-  ptr_A[expert_id] = x + prev_rows * k;
-  ptr_B[expert_id] = weight + expert_id * k * n;
-  ptr_D[expert_id] = out + prev_rows * n;
-  problem_sizes[expert_id] = {static_cast<int>(indptr[expert_id] - prev_rows),
-                              static_cast<int>(n), static_cast<int>(k)};
-  stride_A[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_B[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_D[expert_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
-}
-
-template <typename ElementA, typename ElementB, typename ElementC>
-void cutlass_fp8_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
-                            int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
-                            ElementC* out, cudaStream_t stream) {
-  using Runner = CutlassFP8GroupGemmRunner<ElementA, ElementB, ElementC>;
-  using StrideA = typename Runner::StrideA;
-  using StrideB = typename Runner::StrideB;
-  using StrideC = typename Runner::StrideC;
-
-  Runner runner;
-  std::ptrdiff_t offset = 0;
-  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
-  offset += aligned(sizeof(ElementA*) * num_groups);
-  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
-  offset += aligned(sizeof(ElementB*) * num_groups);
-  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
-  offset += aligned(sizeof(ElementC*) * num_groups);
-  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
-      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
-  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
-  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
-  offset += aligned(sizeof(StrideA) * num_groups);
-  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
-  offset += aligned(sizeof(StrideB) * num_groups);
-  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
-  offset += aligned(sizeof(StrideC) * num_groups);
-  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(
-      ptr_A, ptr_B, ptr_D, problem_sizes, stride_A, stride_B, stride_D, x, weight, out, indptr, n,
-      k, num_groups);
-  offset = aligned(offset, 256);
-  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
-                        nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
-                        workspace_size - offset, num_groups, 1.0f, 0.0f, stream);
-}
+template <>
+struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {
+};
 
 namespace tvm {
 namespace runtime {
@@ -218,10 +58,10 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
   int n = weight->shape[1];
   int k = weight->shape[2];
   cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
-  cutlass_fp8_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
-                         static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
-                         workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
-                         stream);
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
+                     stream);
 }
 
 TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
