
Commit dea9a97

jwfromm authored and facebook-github-bot committed
Update Cutlass to V3.8-2
Differential Revision: D69890673
1 parent 3fed238

6 files changed: +17 additions, -20 deletions


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu

Lines changed: 2 additions & 3 deletions
@@ -99,9 +99,8 @@ at::Tensor bf16i4bf16_rowwise_impl(
                          // threadblocks in a
                          // cluster
   using CooperativeSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
-  using PongSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative;
+  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
   using CooperativeEpilogueSchedule =
       cutlass::epilogue::TmaWarpSpecializedCooperative;
   using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
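
Note: CUTLASS 3.8 drops the dedicated MixedInput kernel schedule tags; the collective builder now selects the mixed-input mainloop (INT4 weights with BF16/FP8 activations) from the operand element types, so the plain TMA warp-specialized schedules are used here and in the batched and f8i4 variants below. A minimal sketch of how the renamed schedule feeds the builder, assuming placeholder layout, alignment, tile, and cluster aliases that are not copied from this file:

// Hedged sketch, not the kernel's exact configuration: the mixed-input path
// is inferred from the INT4 weight / BF16 activation element types.
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;

using CollectiveMainloop =
    typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90,
        cutlass::arch::OpClassTensorOp,
        cutlass::int4b_t, LayoutW, AlignmentW,    // quantized weights (placeholder names)
        cutlass::bfloat16_t, LayoutX, AlignmentX, // activations (placeholder names)
        float,                                    // accumulator
        MmaTileShape, ClusterShape,               // assumed aliases from the kernel
        cutlass::gemm::collective::StageCountAuto,
        KernelSchedule>::CollectiveOp;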

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu

Lines changed: 2 additions & 3 deletions
@@ -103,9 +103,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
                          // threadblocks in a
                          // cluster
   using CooperativeSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
-  using PongSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative;
+  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
   using CooperativeEpilogueSchedule =
       cutlass::epilogue::TmaWarpSpecializedCooperative;
   using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 4 additions & 4 deletions
@@ -67,17 +67,17 @@ struct GroupedGemmConfigs {
       conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   // Implement rowwise scaling epilogue.
-  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcastPtrArray<
+  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
       0,
       TileShape,
-      ElementComputeEpilogue,
+      ElementComputeEpilogue*,
       ElementComputeEpilogue,
       cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
 
-  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcastPtrArray<
+  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
       0,
       TileShape,
-      ElementComputeEpilogue,
+      ElementComputeEpilogue*,
       ElementComputeEpilogue,
       cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
 
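
Note: in CUTLASS 3.8 the Sm90ColBroadcastPtrArray / Sm90RowBroadcastPtrArray nodes were folded into Sm90ColBroadcast / Sm90RowBroadcast; passing a pointer element type (ElementComputeEpilogue*) is what now signals a per-group array of scale pointers for grouped GEMM. A minimal sketch of how these nodes are typically composed into a rowwise-scale visitor tree; the Sm90EVT nesting and the ElementOutput alias are an assumed pattern, not copied from this file:

// Hedged sketch: output = x_scale * (w_scale * accumulator), expressed as an
// epilogue visitor tree built from the XScale / WScale nodes above.
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using ComputeWScale = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::multiplies,
    ElementComputeEpilogue,
    ElementComputeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;
using EVTWScale =
    cutlass::epilogue::fusion::Sm90EVT<ComputeWScale, WScale, Accum>;

using ComputeXScale = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::multiplies,
    ElementOutput, // assumed bf16 output alias
    ElementComputeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueEVT =
    cutlass::epilogue::fusion::Sm90EVT<ComputeXScale, XScale, EVTWScale>;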

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu

Lines changed: 4 additions & 5 deletions
@@ -93,9 +93,8 @@ at::Tensor f8i4bf16_rowwise_impl(
                          // threadblocks in a
                          // cluster
   using CooperativeSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
-  using PongSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative;
+  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
   using CooperativeEpilogueSchedule =
       cutlass::epilogue::TmaWarpSpecializedCooperative;
   using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
@@ -260,7 +259,7 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
     return f8i4bf16_rowwise_impl<
         128,
         256,
-        64,
+        128,
         2,
         1,
         1,
@@ -271,7 +270,7 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
     return f8i4bf16_rowwise_impl<
         128,
         256,
-        64,
+        128,
         2,
         1,
         1,
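
Note: the third template argument of f8i4bf16_rowwise_impl in these dispatch cases is the CTA tile's K extent, so the two 128x256 configurations move from a 64-wide to a 128-wide K tile. A minimal sketch of how those integers are assumed to map onto the tile and cluster shapes inside the implementation; the parameter names are illustrative, not copied from the file:

// Hedged sketch of the assumed parameter-to-shape mapping.
template <int TB_M, int TB_N, int TB_K, int TBS_M, int TBS_N, int TBS_K /*, ...*/>
at::Tensor f8i4bf16_rowwise_impl(/*...*/) {
  using TileShape =
      cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TB_K>>;    // now 128 x 256 x 128
  using ClusterShape =
      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>; // 2 x 1 x 1
  // ... collective builder and epilogue setup elided ...
}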

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 1 addition & 1 deletion
@@ -786,7 +786,7 @@ std::vector<at::Tensor> quantize_fp8_per_tensor(
   for (int i = 0; i < input.dim(); i++) {
     quantized_input_shape.push_back(input.size(i));
   }
-  std::vector<long int> scale_shape = {1};
+  std::vector<long int> scale_shape = {};
   input = input.cuda();
   at::Tensor quantized_input = torch::empty(
       quantized_input_shape,
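
Note: with an empty shape vector, the per-tensor scale is allocated as a 0-dim (scalar) tensor rather than a 1-element 1-D tensor, which broadcasts cleanly against the quantized input. A small standalone libtorch illustration of the difference (not code from this file):

#include <iostream>
#include <torch/torch.h>

int main() {
  auto opts = torch::dtype(torch::kFloat32);
  std::vector<long int> old_shape = {1}; // previous behavior: shape [1]
  std::vector<long int> new_shape = {};  // new behavior: shape [], i.e. a scalar
  at::Tensor scale_1d = torch::empty(old_shape, opts);
  at::Tensor scale_0d = torch::empty(new_shape, opts);
  std::cout << scale_1d.dim() << " " << scale_0d.dim() << "\n"; // prints "1 0"
  return 0;
}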

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 4 additions & 4 deletions
@@ -1120,7 +1120,7 @@ def test_quantize_compile(self) -> None:
     @unittest.skipIf(
         not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
     )
-    def test_gemv(
+    def run_gemv(
         self, test_cases, gemv_op, atol, rtol, quantize_w=False, quantize_x=False
     ):
         for M, N, K in test_cases:
@@ -1150,7 +1150,7 @@ def test_bf16_gemv(self) -> None:
             (1, 7168, 8192),
             (1, 8192, 3584),
         ]
-        self.test_gemv(test_cases, torch.ops.fbgemm.bf16_fast_gemv, 9.0e-3, 9.0e-3)
+        self.run_gemv(test_cases, torch.ops.fbgemm.bf16_fast_gemv, 9.0e-3, 9.0e-3)
 
     @unittest.skipIf(
         not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
@@ -1164,7 +1164,7 @@ def test_bf16_fp8_gemv(self) -> None:
             (1, 7168, 8192),
             (1, 8192, 3584),
         ]
-        self.test_gemv(
+        self.run_gemv(
             test_cases,
             torch.ops.fbgemm.bf16fp8bf16_fast_gemv,
             1.0e-2,
@@ -1182,7 +1182,7 @@ def test_fp8_fp8_gemv(self) -> None:
             (1, 7168, 8192),
             (1, 8192, 3584),
         ]
-        self.test_gemv(
+        self.run_gemv(
             test_cases,
             torch.ops.fbgemm.fp8fp8bf16_fast_gemv,
             9.0e-2,
