
Commit dc741e7

jwfromm authored and facebook-github-bot committed
Enable preshuffled mixed dtype Cutlass Gemm (pytorch#3722)
Summary: WIP to enable the new optimized preshuffled fp8 x int4 GEMM. While the example compiles and runs, it hits a variety of problems: the outputs are either completely incorrect, contain NaNs, or the kernel triggers an illegal memory access. I'm not yet sure why.

Differential Revision: D69955197
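Roughly how the new path is meant to be driven, mirroring the F8I4ShuffledGemm benchmark op added in quantize_ops.py below. This is a minimal sketch, not a tested example from this commit: it assumes the fbgemm_gpu gen_ai extension (with the preshuffle_i4 and f8i4bf16_shuffled ops) is built and importable under the module path shown, that the benchmark class is default-constructible, and that K is divisible by the int4 group size of 128.

import torch

# Assumed import path, taken from the file location of the diff below.
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import F8I4ShuffledGemm

op = F8I4ShuffledGemm()
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)    # activations [M, K]
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)  # weights [N, K]

# One-time prep: fp8 row quantization of x, then int4 group quantization,
# packing, and preshuffling of w and its scales.
xq, wq, x_scale, w_scale = op.quantize(x, w)

# Fused preshuffled fp8 x int4 GEMM; returns a bf16 output of shape [M, N].
y = op.compute(xq, wq, x_scale, w_scale)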
1 parent dbef355 commit dc741e7

4 files changed, +391 -0 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 48 additions & 0 deletions
@@ -1121,6 +1121,54 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class F8I4ShuffledGemm(F8I4RowwiseGemm):
    def _int4_row_quantize(
        self,
        x: torch.Tensor,
        group_size: int = 128,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_bit = 4  # Number of target bits.
        to_quant = x.reshape(-1, group_size).to(torch.float)

        max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
        max_int = 2 ** (n_bit - 1)
        min_int = -(2 ** (n_bit - 1))
        scales = max_val.clamp(min=1e-6) / max_int

        out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)

        # Cast to int8 and restore shape.
        out = out.to(dtype=torch.int8).reshape(x.shape)

        # View scales as rows, groups.
        scales = scales.view(x.shape[0], -1)

        return out, scales

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = self._int4_row_quantize(w)
        # Pack int4 values together.
        wq = self._pack_int4(wq)
        # Shuffle weights and scales for faster compute.
        wq, w_scale = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        out = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, w_scale)
        return out

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_f8i4_preshuffle"


@register_quantize_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
    """
Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"

#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/packed_stride.hpp"

namespace fbgemm_gpu {

at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  // Get shape information from input tensors.
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  int num_groups = w_scale.size(1);
  int group_size = K / num_groups;
  // Allocate output.
  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

  // Define input types.
  using MmaType = cutlass::float_e4m3_t;
  using QuantType = cutlass::int4b_t;
  constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;

  // A Matrix configuration.
  using ElementA = MmaType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B Matrix Configuration.
  using ElementB = QuantType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // We need to manually swap and transpose inputs. Unclear how required this
  // is, though.
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

  // Define layout for shuffled weight tensor.
  using LayoutAtomQuant =
      decltype(cutlass::compute_memory_reordering_atom<MmaType>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

  using ElementScale = MmaType;
  using LayoutScale = cutlass::layout::RowMajor;

  // Output Matrix configuration.
  using ElementC = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations.
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  // TODO tune these shapes.
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::Int<TileShapeK>>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  // TODO Should we use fast accum here?
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
  // Might be the only epilogue schedule that supports swap + transpose.
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Define EVT for rowwise scaling.
  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementAccumulator,
      ElementAccumulator,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementC, // First stage output type.
      ElementAccumulator, // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          EpilogueTileType,
          ElementAccumulator,
          ElementAccumulator,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          EpilogueSchedule,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloopScaleOnly =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Transpose,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, // Indicates ProblemShape
      CollectiveMainloopScaleOnly,
      CollectiveEpilogue>;

  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,
      CollectiveMainloopShuffled,
      CollectiveEpilogue>;

  using GemmScaleOnly =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelScaleOnly::StrideC;

  /// Initialization
  auto shape_B = cute::make_shape(N, K, 1);
  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  StrideC stride_C =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
  using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
  StrideS stride_S = cutlass::make_cute_packed_stride(
      StrideS{}, cute::make_shape(N, num_groups, 1));

  // Define Gemm arguments.
  typename GemmShuffled::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {N, M, K, 1},
      {reinterpret_cast<ElementB*>(WQ.data_ptr()),
       layout_B_reordered,
       reinterpret_cast<ElementA*>(XQ.data_ptr()),
       stride_A,
       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
       stride_S,
       group_size},
      {{},
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C,
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C}};

  arguments.epilogue.thread = {
      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
      {}, // Accumulator
      {}, // Multiplies
  };

  // Launch the workload.
  GemmShuffled gemm;

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation.
  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

  // Allocate workspace memory.
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported or not.
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer.
  status = gemm.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm(at::cuda::getCurrentCUDAStream());

  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}

} // namespace fbgemm_gpu
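Given that the commit summary reports incorrect outputs, a dequantized PyTorch reference for what this kernel is expected to compute can help with debugging. The sketch below is an assumption drawn from the code above, not part of this commit: x_scale is treated as a per-row dequantization scale applied in the epilogue, and w_scale as a per-group dequantization scale applied while upconverting B in the mainloop. It operates on the unpacked, un-shuffled int4 weights (i.e. before _pack_int4 / preshuffle_i4).

import torch

def f8i4_reference(xq, wq_int4, x_scale, w_scale, group_size=128):
    # xq:      [M, K] float8_e4m3fn activations
    # wq_int4: [N, K] int8 tensor holding unpacked int4 values in [-8, 7]
    # x_scale: [M] float32 per-row activation scales
    # w_scale: [N, K // group_size] per-group weight scales
    x_ref = xq.float() * x_scale.float()[:, None]
    w_ref = wq_int4.float() * w_scale.float().repeat_interleave(group_size, dim=1)
    # Matches the kernel's bf16 output of shape [M, N].
    return (x_ref @ w_ref.t()).to(torch.bfloat16)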
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cute/layout.hpp"
#include "cutlass/detail/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/mixed_dtype_utils.hpp"

namespace fbgemm_gpu {

std::tuple<at::Tensor, at::Tensor> preshuffle_i4(
    at::Tensor WQ,
    at::Tensor w_scale) {
  // Check that w_scale is the proper type. If not, quantize it.
  if (w_scale.dtype() != at::kFloat8_e4m3fn) {
    TORCH_WARN(
        "Weight scale must be FP8 for preshuffled GEMM. Performing downcasting.");
    w_scale = w_scale.to(WQ.options().dtype(at::kFloat8_e4m3fn));
  }
  // Start by allocating space for shuffled tensors.
  at::Tensor WQ_shuffled = at::empty_like(WQ);
  at::Tensor w_scale_packed = at::empty_like(w_scale);
  // WQ has two int4 values packed into each int8 dtype, so the size
  // is larger than it seems.
  size_t WQ_size = 2 * WQ.numel();
  // Encode weights to enable efficient lookup.
  cutlass::unified_encode_int4b(
      reinterpret_cast<cutlass::int4b_t*>(WQ.data_ptr()),
      reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
      WQ_size);

  // Pack scale values. Size is divided by 8 to account for packing.
  size_t w_scale_size = w_scale.numel() / 8;
  cutlass::pack_scale_fp8(
      reinterpret_cast<cutlass::float_e4m3_t*>(w_scale.data_ptr()),
      reinterpret_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(
          w_scale_packed.data_ptr()),
      w_scale_size);

  // Next we need to shuffle B. To do this, we define a few helper objects.
  const int N = WQ.size(0);
  const int K = 2 * WQ.size(1);
  auto shape_B = cute::make_shape(N, K, 1);
  using LayoutB = cutlass::layout::ColumnMajor;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
  using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<
                                   cutlass::float_e4m3_t>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));
  StrideB stride_B;
  auto layout_B = make_layout(shape_B, stride_B);
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

  // Now we're ready to reorder the tensor into the proper layout.
  cutlass::reorder_tensor(
      reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
      layout_B,
      layout_B_reordered);

  // Tensors should now be preshuffled and ready for use.
  return {WQ_shuffled, w_scale_packed};
}

} // namespace fbgemm_gpu
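A small usage sketch of the op above, under the assumption that it is registered as torch.ops.fbgemm.preshuffle_i4 (as it is called from quantize_ops.py). Based on the implementation, both outputs should keep their input shapes: the weights are only re-encoded and reordered, and the scales are downcast to fp8 e4m3 (with a warning when they arrive as fp32) and packed in groups of 8.

import torch

# Hypothetical sizes. K counts unpacked int4 elements, so the packed weight
# tensor holds K // 2 int8 values per row.
N, K, group_size = 4096, 4096, 128
wq = torch.randint(-128, 128, (N, K // 2), device="cuda", dtype=torch.int8)  # packed int4 weights
w_scale = torch.rand(N, K // group_size, device="cuda")                      # fp32 group scales

wq_shuffled, w_scale_packed = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)

# Expected (per the implementation above): shapes preserved, scales now fp8.
assert wq_shuffled.shape == wq.shape and wq_shuffled.dtype == torch.int8
assert w_scale_packed.shape == w_scale.shape
assert w_scale_packed.dtype == torch.float8_e4m3fn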
