-
-
Notifications
You must be signed in to change notification settings - Fork 5k
/
scaled_mm_dq_c2x.cu
329 lines (272 loc) · 12.4 KB
/
scaled_mm_dq_c2x.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
namespace {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
evt0_compute_args};
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
typename Gemm::EVTD::Arguments epilogue_args{
evt1_compute_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}
} // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
}
}
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
}
}
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
}
}