
Commit 70477ce

jwfromm authored and facebook-github-bot committed
Enable preshuffled mixed dtype Cutlass Gemm (#3722)
Summary: WIP to enable the new optimized preshuffled fp8 x int4 GEMM. While the example compiles and runs, it still hits a variety of problems: the outputs are either completely incorrect, contain NaNs, or the kernel hits an illegal memory access. I'm not yet sure why. Differential Revision: D69955197
1 parent dea9a97 commit 70477ce
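
For orientation, here is a minimal sketch of how the two new ops are expected to be exercised end to end. The tensor shapes and dtypes are assumptions inferred from the kernel code in this diff (FP8 e4m3 activations with per-row float32 scales, int4 weights packed two-per-int8 with per-group FP8 scales); the sizes M, N, K, and group_size are placeholders, not a tested configuration.

# Illustrative sketch only; shapes/dtypes are assumptions inferred from this diff.
import torch

M, N, K, group_size = 128, 256, 512, 128

# FP8 activations with one float32 scale per row (as quantize_fp8_row would produce).
xq = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
x_scale = torch.ones(M, device="cuda", dtype=torch.float32)

# Row-quantized int4 weights, two values packed per int8, plus per-group FP8 scales.
wq = torch.randint(-128, 127, (N, K // 2), device="cuda", dtype=torch.int8)
w_scale = torch.ones(N, K // group_size, device="cuda").to(torch.float8_e4m3fn)

# One-time reordering of weights and scales into the layout the kernel expects.
wq_shuffled, w_scale_packed = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)

# Mixed-dtype GEMM producing a bf16 output of shape [M, N].
y = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq_shuffled, x_scale, w_scale_packed)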

File tree

4 files changed: +396, -0 lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 48 additions & 0 deletions
@@ -1121,6 +1121,54 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class F8I4ShuffledGemm(F8I4RowwiseGemm):
    def _int4_row_quantize(
        self,
        x: torch.Tensor,
        group_size: int = 128,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_bit = 4  # Number of target bits.
        to_quant = x.reshape(-1, group_size).to(torch.float)

        max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
        max_int = 2 ** (n_bit - 1)
        min_int = -(2 ** (n_bit - 1))
        scales = max_val.clamp(min=1e-6) / max_int

        out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)

        # Cast to int8 and restore shape.
        out = out.to(dtype=torch.int8).reshape(x.shape)

        # View scales as rows, groups.
        scales = scales.view(x.shape[0], -1)

        return out, scales

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = self._int4_row_quantize(w)
        # Pack int4 values together.
        wq = self._pack_int4(wq)
        # Shuffle weights and scales for faster compute.
        wq, w_scale = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        out = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, w_scale)
        return out

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_f8i4_preshuffle"


@register_quantize_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
    """
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"

#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/packed_stride.hpp"

namespace fbgemm_gpu {

at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  // Get shape information from input tensors.
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  // Make sure w_scale is in proper format.
  TORCH_CHECK(
      w_scale.size(2) == 8,
      "Weights and scales must be prepacked with preshuffle_i4.");
  int num_groups = w_scale.size(1);
  int group_size = K / num_groups;
  // Allocate output.
  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

  // Define input types.
  using MmaType = cutlass::float_e4m3_t;
  using QuantType = cutlass::int4b_t;
  constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;

  // A Matrix configuration.
  using ElementA = MmaType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B Matrix configuration.
  using ElementB = QuantType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // We need to manually swap and transpose inputs. Unclear how required this
  // is, though.
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

  // Define layout for shuffled weight tensor.
  using LayoutAtomQuant =
      decltype(cutlass::compute_memory_reordering_atom<MmaType>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

  using ElementScale = MmaType;
  using LayoutScale = cutlass::layout::RowMajor;

  // Output Matrix configuration.
  using ElementC = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations.
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  // TODO tune these shapes.
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::Int<TileShapeK>>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  // TODO Should we use fast accum here?
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
  // Might be the only epilogue schedule that supports swap + transpose.
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Define EVT for rowwise scaling.
  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementAccumulator,
      ElementAccumulator,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementC, // First stage output type.
      ElementAccumulator, // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          EpilogueTileType,
          ElementAccumulator,
          ElementAccumulator,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          EpilogueSchedule,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloopScaleOnly =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Transpose,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, // Indicates ProblemShape
      CollectiveMainloopScaleOnly,
      CollectiveEpilogue>;

  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,
      CollectiveMainloopShuffled,
      CollectiveEpilogue>;

  using GemmScaleOnly =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelScaleOnly::StrideC;

  /// Initialization
  auto shape_B = cute::make_shape(N, K, 1);
  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  StrideC stride_C =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
  using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
  StrideS stride_S = cutlass::make_cute_packed_stride(
      StrideS{}, cute::make_shape(N, num_groups, 1));

  // Define Gemm arguments.
  typename GemmShuffled::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {N, M, K, 1},
      {reinterpret_cast<ElementB*>(WQ.data_ptr()),
       layout_B_reordered,
       reinterpret_cast<ElementA*>(XQ.data_ptr()),
       stride_A,
       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
       stride_S,
       group_size},
      {{},
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C,
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C}};

  arguments.epilogue.thread = {
      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
      {}, // Accumulator
      {}, // Multiplies
  };

  // Launch the workload.
  GemmShuffled gemm;

  // Using the arguments, query for extra workspace required for the matrix
  // multiplication computation.
  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

  // Allocate workspace memory.
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported.
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize the CUTLASS kernel with arguments and workspace pointer.
  status = gemm.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm(at::cuda::getCurrentCUDAStream());

  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run: ") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}

} // namespace fbgemm_gpu
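
One non-obvious detail above is the swap-and-transpose arrangement: the problem shape is passed as {N, M, K, 1} with the quantized weight in the A slot and the activations in the B slot, and the output strides are built for an N x M view of Y. Under that reading the kernel computes Y^T = W_dq * X^T and writes it into the same memory as the row-major [M, N] result X * W_dq^T. A small sketch of that identity (plain PyTorch, illustrative only; it does not call the new kernel):

# Sketch of the swap + transpose identity the launch arguments appear to rely on.
import torch

M, N, K = 4, 6, 8
x = torch.randn(M, K)
w = torch.randn(N, K)

y_direct = x @ w.t()         # [M, N], the layout of the allocated output Y
y_swapped = (w @ x.t()).t()  # compute the N x M problem, then view it transposed
assert torch.allclose(y_direct, y_swapped, atol=1e-5)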
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cute/layout.hpp"
#include "cutlass/detail/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/mixed_dtype_utils.hpp"

namespace fbgemm_gpu {

std::tuple<at::Tensor, at::Tensor> preshuffle_i4(
    at::Tensor WQ,
    at::Tensor w_scale) {
  // Check that w_scale is the proper type. If not, quantize it.
  if (w_scale.dtype() != at::kFloat8_e4m3fn) {
    TORCH_WARN(
        "Weight scale must be FP8 for preshuffled GEMM. Performing downcasting.");
    w_scale = w_scale.to(WQ.options().dtype(at::kFloat8_e4m3fn));
  }
  // Start by allocating space for shuffled tensors.
  at::Tensor WQ_shuffled = at::empty_like(WQ);
  // Packed scale contains 8 lookup values for each original scale element.
  at::Tensor w_scale_packed =
      at::empty({w_scale.size(0), w_scale.size(1), 8}, w_scale.options());
  // WQ has two int4 values packed into each int8 dtype, so the size
  // is larger than it seems.
  size_t WQ_size = 2 * WQ.numel();
  // Encode weights to enable efficient lookup.
  cutlass::unified_encode_int4b(
      reinterpret_cast<cutlass::int4b_t*>(WQ.data_ptr()),
      reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
      WQ_size);

  size_t w_scale_size = w_scale.numel();
  cutlass::pack_scale_fp8(
      reinterpret_cast<cutlass::float_e4m3_t*>(w_scale.data_ptr()),
      reinterpret_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(
          w_scale_packed.data_ptr()),
      w_scale_size);

  // Next we need to shuffle B. To do this, we define a few helper objects.
  const int N = WQ.size(0);
  const int K = 2 * WQ.size(1);
  auto shape_B = cute::make_shape(N, K, 1);
  using LayoutB = cutlass::layout::ColumnMajor;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
  using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<
                                   cutlass::float_e4m3_t>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));
  StrideB stride_B;
  auto layout_B = make_layout(shape_B, stride_B);
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

  // Now we're ready to reorder the tensor into the proper layout.
  cutlass::reorder_tensor(
      reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
      layout_B,
      layout_B_reordered);

  // Tensors should now be preshuffled and ready for use.
  return {WQ_shuffled, w_scale_packed};
}

} // namespace fbgemm_gpu
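
As a quick summary of the calling convention implied above (a hedged sketch, not an authoritative spec): preshuffle_i4 is meant to be run once per weight tensor, keeps the packed [N, K/2] int8 weight shape while reordering its contents, and expands every scale element into 8 FP8 lookup values; float32 scales are accepted but downcast with a warning.

# Hedged sketch: shapes are assumptions taken from the C++ above, values are placeholders.
import torch

N, K, group_size = 256, 512, 128
wq = torch.randint(-128, 127, (N, K // 2), device="cuda", dtype=torch.int8)
w_scale_fp32 = torch.ones(N, K // group_size, device="cuda", dtype=torch.float32)

# Emits the "Weight scale must be FP8" warning and downcasts internally.
wq_shuffled, w_scale_packed = torch.ops.fbgemm.preshuffle_i4(wq, w_scale_fp32)

assert wq_shuffled.shape == wq.shape                    # same packed shape, reordered contents
assert w_scale_packed.shape == (N, K // group_size, 8)  # 8 lookup values per original scale
assert w_scale_packed.dtype == torch.float8_e4m3fn      # downcast applied before packing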
