
Commit 8b881b9

feature: add cutlass backend in fp4 mm
Supports A/B input type E2M1 and block quant type E4M3 with block size 16 and scale layout 128x4. Output dtypes are bfloat16 and fp16; kernels ported from TRT-LLM. Signed-off-by: Vincent Huang <vincenth@nvidia.com>
1 parent cd928a7 commit 8b881b9
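The A and B operands are E2M1 values packed two per byte (the checks in fp4_gemm_cutlass.cu below use FLOAT4_E2M1X2 = at::ScalarType::Byte and a K / 2 innermost dimension). A minimal C++ sketch of that packing convention, assuming the even-indexed element occupies the low nibble — an assumption made for illustration, not something this commit specifies:

```cpp
#include <cstdint>

// Pack two 4-bit E2M1 codes into one byte. Nibble order (low nibble = even-indexed
// element) is assumed here for illustration only.
inline uint8_t pack_e2m1x2(uint8_t even_code, uint8_t odd_code) {
  return static_cast<uint8_t>((even_code & 0x0F) | ((odd_code & 0x0F) << 4));
}

// Consequently, a logical [M, K] FP4 matrix is stored as an [M, K / 2] uint8 tensor,
// matching the FLOAT4_E2M1X2 shape checks in fp4_gemm_cutlass.cu below.
```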

File tree

10 files changed: +1608 −6 lines


csrc/fp4_gemm_cutlass.cu

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/cuda/EmptyTensor.h>
#include <cuda_fp16.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

#include "flashinfer/gemm/cutlass_gemm_configs.h"
#include "flashinfer/gemm/fp4_gemm_cutlass.h"
#include "flashinfer/gemm/fp4_gemm_cutlass_template.h"
#include "pytorch_extension_utils.h"

using flashinfer::gemm::ClusterShape;
using flashinfer::gemm::CutlassFp4GemmRunner;
using flashinfer::gemm::CutlassFp4GemmRunnerInterface;
using flashinfer::gemm::CutlassGemmConfig;
using flashinfer::gemm::CutlassTileConfigSM100;
using flashinfer::gemm::EpilogueScheduleType;
using flashinfer::gemm::FP4GemmType;
using flashinfer::gemm::MainloopScheduleType;

namespace flashinfer {
namespace gemm {
template class CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4>;
template class CutlassFp4GemmRunner<half, FP4GemmType::W4A4_NVFP4_NVFP4>;
}  // namespace gemm
}  // namespace flashinfer

namespace torch_ext {

namespace {

CutlassGemmConfig getFp4GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
  auto getCutlassFp4GemmConfigs = []() {
    CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static std::vector<CutlassGemmConfig> globalConfigs = getCutlassFp4GemmConfigs();
  TORCH_CHECK(tactic >= 0 && tactic < globalConfigs.size(), "tactic must be between 0 and ",
              globalConfigs.size());
  return globalConfigs[tactic];
}

template <typename T>
void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
             at::Tensor const& mat1Scale, at::Tensor const& mat2Scale,
             at::Tensor const& globalScale, int64_t m, int64_t n, int64_t k, int64_t batch_count,
             CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
  CutlassFp4GemmRunner<T, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;

  int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count);
  int64_t const provided_workspace_size =
      workspace_buffer.numel() * workspace_buffer.element_size();

  auto runKernel = [&](void* workspace) {
    gemmRunner.gemm(out.data_ptr(), mat1.const_data_ptr(), mat2.const_data_ptr(),
                    mat1Scale.const_data_ptr(), mat2Scale.const_data_ptr(),
                    globalScale.data_ptr<float>(), m, n, k, batch_count, gemmConfig,
                    reinterpret_cast<char*>(workspace), required_workspace_size,
                    at::cuda::getCurrentCUDAStream(mat1.get_device()));
  };

  if (provided_workspace_size < required_workspace_size) {
    at::Tensor new_workspace = at::detail::empty_cuda(
        {required_workspace_size}, at::ScalarType::Char, mat1.device(), std::nullopt);

    runKernel(new_workspace.data_ptr());
  } else {
    runKernel(workspace_buffer.data_ptr());
  }
}

constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;  // uint8_t
constexpr auto SF_DTYPE = at::ScalarType::Byte;       // uint8_t

#define CHECK_GPU_INPUT(x, st)                              \
  TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")     \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")  \
  TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x)

// mat1: [B, M, K / 2], FLOAT4_E2M1X2 or [B, M, K], FLOAT8_E4M3FN
// mat2: [B, N, K / 2], FLOAT4_E2M1X2
// out: [B, M, N], fp16/bf16/fp32
// mat1Scale: ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
// mat2Scale: ceil(N / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
// globalScale: [1], 1 / (((448 * 6) / mat1.abs().max()) * ((448 * 6) / mat2.abs().max()))
// B = 1 for GEMM op as a special case
at::Tensor fp4_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
                        at::Tensor const& mat2Scale, at::Tensor const& globalScale, at::Tensor out,
                        at::Tensor workspace_buffer, int64_t tactic) {
  CHECK_GPU_INPUT(mat1, FLOAT4_E2M1X2);
  CHECK_GPU_INPUT(mat2, FLOAT4_E2M1X2);

  int mat2_k_scale = 1;

  CHECK_GPU_INPUT(mat1Scale, SF_DTYPE);
  CHECK_GPU_INPUT(mat2Scale, SF_DTYPE);

  CHECK_GPU_INPUT(globalScale, at::ScalarType::Float);

  int64_t m, n, k, b;
  if (mat1.dim() == 2) {
    TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
    TORCH_CHECK(mat1.sizes()[1] == mat2.sizes()[1] * mat2_k_scale,
                "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x",
                mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
    m = mat1.sizes()[0];
    n = mat2.sizes()[0];
    k = mat2.sizes()[1] * 2;
    b = 1;
  } else if (mat1.dim() == 3) {
    TORCH_CHECK(mat2.dim() == 3, "mat2 must be a batch of matrices");
    TORCH_CHECK(mat1.sizes()[0] == mat2.sizes()[0], "mat1 and mat2 must have the same batch size (",
                mat1.sizes()[0], " and ", mat2.sizes()[0], ")");
    TORCH_CHECK(mat1.sizes()[2] == mat2.sizes()[2] * mat2_k_scale,
                "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[1], "x",
                mat1.sizes()[2], " and ", mat2.sizes()[1], "x", mat2.sizes()[2], ")");
    m = mat1.sizes()[1];
    n = mat2.sizes()[1];
    k = mat2.sizes()[2] * 2;
    b = mat1.sizes()[0];
  } else {
    C10_THROW_ERROR(NotImplementedError, "mat1 must be a matrix or a batch of matrices");
  }

  // No heuristic for now, we rely on the autotuner to select the best tactic.
  if (tactic == -1) {
    tactic = 0;
  }
  auto config = getFp4GemmConfig(m, n, k, tactic);

  constexpr int alignment = 32;
  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
              ", but got mat1 shape: (", mat1.sizes()[0], "x", mat1.sizes()[1], "), k: ", k, ".");
  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
              ", but got mat2 shape: (", mat2.sizes()[0], "x", mat2.sizes()[1], ").");

  // Validate out dimensions
  std::vector<int64_t> out_shape =
      mat1.dim() == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
  TORCH_CHECK(out.dim() == out_shape.size(), "out must have ", out_shape.size(),
              " dimensions, but got ", out.dim());
  for (int i = 0; i < out_shape.size(); ++i) {
    TORCH_CHECK(out.sizes()[i] == out_shape[i], "out shape mismatch at dimension ", i,
                ": expected ", out_shape[i], ", got ", out.sizes()[i]);
  }

  c10::ScalarType out_dtype = out.scalar_type();

  switch (out_dtype) {
    case at::ScalarType::Half:
      runGemm<half>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config,
                    workspace_buffer);
      break;
    case at::ScalarType::BFloat16:
      runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config,
                             workspace_buffer);
      break;
    default:
      TORCH_CHECK(false, "out_dtype must be one of fp16/bf16.");
  }
  return out;
}

}  // namespace

at::Tensor fp4_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
                    at::Tensor const& mat2Scale, at::Tensor const& globalScale, at::Tensor out,
                    at::Tensor workspace_buffer, int64_t tactic) {
  return fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic);
}

int64_t fp4_gemm_tactic_num() {
  auto getCutlassConfigs = []() {
    CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static int64_t totalTactics = getCutlassConfigs().size();
  return totalTactics;
}

}  // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("fp4_gemm", &torch_ext::fp4_gemm);
  m.def("fp4_gemm_tactic_num", &torch_ext::fp4_gemm_tactic_num);
}
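
To make the scale-factor sizing in the comment block above concrete, here is a small standalone sketch that evaluates ceil(rows / 128) * 128 * ceil(K / sfVecSize / 4) * 4 for the 128x4 layout, assuming sfVecSize = 16 (the block size named in the commit message) and K divisible by sfVecSize; the shapes and numbers are illustrative only:

```cpp
#include <cstdint>
#include <cstdio>

// Padded element count of a 128x4-layout scale-factor tensor, following the formula in
// the comment block above: ceil(rows / 128) * 128 * ceil(K / sfVecSize / 4) * 4.
int64_t sf_numel(int64_t rows, int64_t k, int64_t sf_vec_size) {
  int64_t padded_rows = (rows + 127) / 128 * 128;       // pad rows to a multiple of 128
  int64_t padded_cols = (k / sf_vec_size + 3) / 4 * 4;  // pad scale columns to a multiple of 4
  return padded_rows * padded_cols;
}

int main() {
  // Illustrative shapes: M = 200, N = 4096, K = 4096, one E4M3 scale per 16 E2M1 values.
  // mat1Scale pads M up to 256 rows: 256 * (4096 / 16) = 65536 uint8 elements.
  std::printf("mat1Scale elements: %lld\n", static_cast<long long>(sf_numel(200, 4096, 16)));
  std::printf("mat2Scale elements: %lld\n", static_cast<long long>(sf_numel(4096, 4096, 16)));
  return 0;
}
```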

csrc/fp4_gemm_cutlass.jinja

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "flashinfer/gemm/fp4_gemm_cutlass_template.h"

namespace flashinfer {
namespace gemm {
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM)

}  // namespace gemm
}  // namespace flashinfer
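
This Jinja template is presumably rendered once per (output type, CTA tile) combination by the build's codegen; the eight lines differ only in cluster shape (1,1,1 through 4,4,1) and in whether the 1-SM or 2-SM schedule is used. As a purely hypothetical illustration — __nv_bfloat16 and a 128x128x128 tile are placeholder values, not taken from this commit — a rendered file would look like:

```cpp
// Hypothetical rendering of fp4_gemm_cutlass.jinja; only two of the eight
// instantiations are shown, and the 128x128x128 tile shape is a placeholder.
#include "flashinfer/gemm/fp4_gemm_cutlass_template.h"

namespace flashinfer {
namespace gemm {
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 128, 1, 1, 1, _1SM)
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER(__nv_bfloat16, 128, 128, 128, 2, 2, 1, _2SM)
}  // namespace gemm
}  // namespace flashinfer
```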
