Commit fb82c93

add int8 packed gemm support on CPU device

ghstack-source-id: c0f5133
Pull Request resolved: #118056
1 parent b32ace4
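For context, a minimal usage sketch of the op this commit adds. It is not part of the diff: shapes and dtypes follow the TORCH_CHECKs in _weight_int8pack_mm_cpu below (A: bfloat16 [M, K], contiguous; B: int8 [N, K], contiguous; scales: 1-D of length N), the at::_weight_int8pack_mm wrapper is the one generated from the native_functions.yaml entry in this commit, and the concrete sizes are arbitrary.

#include <ATen/ATen.h>

int main() {
  int64_t M = 4, N = 32, K = 64;
  // Activation: bfloat16, contiguous, shape [M, K]
  auto A = at::randn({M, K}, at::kBFloat16);
  // Quantized weight: int8, contiguous, shape [N, K]
  auto B = at::randint(-128, 127, {N, K}, at::kChar);
  // Per-output-channel scales: 1-D, length N (bfloat16, matching what the CPU kernel reads)
  auto scales = at::rand({N}, at::kBFloat16);

  // C has shape [M, N] and A's dtype (bfloat16)
  auto C = at::_weight_int8pack_mm(A, B, scales);
  return 0;
}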

8 files changed: 440 additions, 0 deletions


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 33 additions & 0 deletions
@@ -37,6 +37,7 @@
 #include <ATen/ops/_linalg_slogdet_native.h>
 #include <ATen/ops/_unsafe_view.h>
 #include <ATen/ops/_weight_int4pack_mm_native.h>
+#include <ATen/ops/_weight_int8pack_mm_native.h>
 #include <ATen/ops/addbmm_native.h>
 #include <ATen/ops/addmm_native.h>
 #include <ATen/ops/addr.h>
@@ -3375,6 +3376,7 @@ Tensor kron(const Tensor& self, const Tensor& other) {
 // Weight Only Quantization Gemm
 DEFINE_DISPATCH(weight_to_int4pack_stub);
 DEFINE_DISPATCH(int4pack_mm_stub);
+DEFINE_DISPATCH(int8pack_mm_stub);
 
 Tensor _convert_weight_to_int4pack_cpu(
     const Tensor& in,
@@ -3436,5 +3438,36 @@ Tensor _weight_int4pack_mm_cpu(
   return C;
 }
 
+Tensor _weight_int8pack_mm_cpu(
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& scales) {
+
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
+
+  TORCH_CHECK(A.dtype() == kBFloat16,
+      "_weight_int8pack_mm: expect A to be bfloat16 tensor.");
+  TORCH_CHECK(A.is_contiguous(),
+      "_weight_int8pack_mm: expect A to be contiguous.");
+  TORCH_CHECK(A.dim() == 2,
+      "_weight_int8pack_mm: expect A to be 2D tensor.");
+
+  TORCH_CHECK(B.dtype() == kChar,
+      "_weight_int8pack_mm: expect B to be int8 tensor.");
+  TORCH_CHECK(B.is_contiguous(),
+      "_weight_int8pack_mm: expect B to be contiguous.");
+  TORCH_CHECK(B.size(1) == K);
+
+  TORCH_CHECK(scales.dim() == 1);
+  TORCH_CHECK(scales.size(0) == N);
+
+  auto C = at::empty({M, N}, A.options());
+  int8pack_mm_stub(kCPU, C, A, B, scales);
+
+  return C;
+}
+
 } // namespace native
 } // namespace at
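In plain tensor terms, the new op computes C[m][n] = sum over k of A[m][k] * (B[n][k] * scales[n]): B is dequantized per output channel and then multiplied against A, exactly as in the scalar fallback tinygemm_kernel in the new CPU kernel file below. A hedged reference check, not part of the commit and reusing A, B and scales from the sketch above, could look like:

  // Reference computation with stock ATen ops.
  // Dequantize B per output channel n: B_deq[n][k] = B[n][k] * scales[n].
  auto B_deq = B.to(at::kFloat) * scales.to(at::kFloat).unsqueeze(1);
  auto C_ref = at::matmul(A.to(at::kFloat), B_deq.t()).to(at::kBFloat16);
  // C_ref should agree with at::_weight_int8pack_mm(A, B, scales)
  // up to bfloat16 rounding and summation-order differences.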

aten/src/ATen/native/LinearAlgebra.h

Lines changed: 2 additions & 0 deletions
@@ -17,9 +17,11 @@ namespace at::native {
 using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
 using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&);
 using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int64_t, const Tensor&);
+using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
 
 DECLARE_DISPATCH(addr_fn, addr_stub);
 DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
 DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
+DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);
 
 } // namespace at::native

Lines changed: 301 additions & 0 deletions
@@ -0,0 +1,301 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/utils.h>
+#include <ATen/native/LinearAlgebra.h>
+#include <c10/util/irange.h>
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define RESTRICT __restrict
+#else
+#define RESTRICT __restrict__
+#endif
+
+namespace at::native {
+
+namespace {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+// A block : {BLOCK_M, BLOCK_K}, lda = K
+// B block : {BLOCK_K, BLOCK_N}, ldb = K
+// C block : {BLOCK_M, BLOCK_N}, ldc = N
+//
+// scales block: {BLOCK_N}
+//
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const BFloat16* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const BFloat16* RESTRICT scales,
+    BFloat16* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+
+  constexpr int ROWS = BLOCK_M;
+  constexpr int COLS = BLOCK_N;
+
+  const int PREFETCH_SIZE_K = 16 * 4;
+
+  __m512 va;
+  __m512 vb[COLS];
+  __m512 vc[ROWS * COLS];
+  __m512 scale[COLS];
+
+  auto load_scale = [&](int i) {
+    float ss = static_cast<float>(scales[i]);
+    scale[i] = _mm512_set1_ps(ss);
+  };
+  compile_time_for<COLS>::op(load_scale);
+
+  auto loadc = [&](auto i) {
+    vc[i] = _mm512_setzero_ps();
+  };
+  compile_time_for<ROWS * COLS>::op(loadc);
+
+  auto compute = [&](auto i, int k) {
+    constexpr int row = i / COLS;
+    constexpr int col = i % COLS;
+
+    if constexpr (col == 0) {
+      __m256i a16 = _mm256_load_si256((__m256i*)(A + row * lda + k));
+      if (k + PREFETCH_SIZE_K < K) {
+        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
+      }
+      vec::cvtbf16_fp32(a16, va);
+    }
+
+    if constexpr (row == 0) {
+      __m128i b8 = _mm_load_si128((__m128i*)(B + col * ldb + k));
+      if (k + PREFETCH_SIZE_K < K) {
+        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
+      }
+      __m512i b32 = _mm512_cvtepi8_epi32(b8);
+      vb[col] = _mm512_cvtepi32_ps(b32);
+      vb[col] = _mm512_mul_ps(vb[col], scale[col]);
+    }
+
+    constexpr int idx = row * COLS + col;
+    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
+  };
+
+  for (int k = 0; k < K; k += 16) {
+    compile_time_for<ROWS * COLS>::op(compute, k);
+  }
+
+  auto storec = [&](auto i) {
+    constexpr int row = i / COLS;
+    constexpr int col = i % COLS;
+    C[row * ldc + col] = static_cast<BFloat16>(_mm512_reduce_add_ps(vc[i]));
+  };
+  compile_time_for<ROWS * COLS>::op(storec);
+}
+
+#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+static inline float _mm256_reduce_add_ps(__m256& v) {
+  __m256 v1 = _mm256_permute2f128_ps(v, v, 0x1);
+  v = _mm256_add_ps(v, v1);
+  v1 = _mm256_shuffle_ps(v, v, 0x4E);
+  v = _mm256_add_ps(v, v1);
+  v1 = _mm256_shuffle_ps(v, v, 0xB1);
+  v = _mm256_add_ps(v, v1);
+  return _mm256_cvtss_f32(v);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const BFloat16* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const BFloat16* RESTRICT scales,
+    BFloat16* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+
+  constexpr int ROWS = BLOCK_M;
+  constexpr int COLS = BLOCK_N;
+
+  const int PREFETCH_SIZE_K = 16 * 4;
+
+  __m256 va;
+  __m256 vb[COLS];
+  __m256 vc[ROWS * COLS];
+  __m256 scale[COLS];
+
+  auto load_scale = [&](int i) {
+    float ss = static_cast<float>(scales[i]);
+    scale[i] = _mm256_set1_ps(ss);
+  };
+  compile_time_for<COLS>::op(load_scale);
+
+  auto loadc = [&](auto i) {
+    vc[i] = _mm256_setzero_ps();
+  };
+  compile_time_for<ROWS * COLS>::op(loadc);
+
+  auto compute = [&](auto i, int k) {
+    constexpr int row = i / COLS;
+    constexpr int col = i % COLS;
+
+    if constexpr (col == 0) {
+      __m128i a16 = _mm_load_si128((__m128i*)(A + row * lda + k));
+      if (k + PREFETCH_SIZE_K < K) {
+        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
+      }
+      vec::cvtbf16_fp32(a16, va);
+    }
+
+    if constexpr (row == 0) {
+      __m128i b8 = _mm_loadu_si64((__m128i*)(B + col * ldb + k));
+      if (k + PREFETCH_SIZE_K < K) {
+        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
+      }
+      __m256i b32 = _mm256_cvtepi8_epi32(b8);
+      vb[col] = _mm256_cvtepi32_ps(b32);
+      vb[col] = _mm256_mul_ps(vb[col], scale[col]);
+    }
+
+    constexpr int idx = row * COLS + col;
+    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
+  };
+
+  for (int k = 0; k < K; k += 8) {
+    compile_time_for<ROWS * COLS>::op(compute, k);
+  }
+
+  auto storec = [&](auto i) {
+    constexpr int row = i / COLS;
+    constexpr int col = i % COLS;
+    C[row * ldc + col] = static_cast<BFloat16>(_mm256_reduce_add_ps(vc[i]));
+  };
+  compile_time_for<ROWS * COLS>::op(storec);
+}
+
+#else
+
+// non-vectorized version
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const BFloat16* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const BFloat16* RESTRICT scales,
+    BFloat16* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+
+  for (const auto m : c10::irange(BLOCK_M)) {
+    for (const auto n : c10::irange(BLOCK_N)) {
+      float c_val = 0;
+      float scale_val = static_cast<float>(scales[n]);
+      for (const auto k : c10::irange(K)) {
+        float a_val = static_cast<float>(A[m * lda + k]);
+        float b_val = static_cast<float>(B[n * ldb + k]);
+        c_val += a_val * (b_val * scale_val);
+      }
+      C[m * ldc + n] = c_val;
+    }
+  }
+}
+
+#endif
+
+#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE)                 \
+  tinygemm_kernel<MB_SIZE, NB_SIZE>(                             \
+      A_ptr, B_ptr, S_ptr, C_ptr,                                \
+      K, K, N, K);
+
+#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                         \
+  switch (nb_size) {                                             \
+    case 1:                                                      \
+      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 1);                        \
+      break;                                                     \
+    case 2:                                                      \
+      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 2);                        \
+      break;                                                     \
+    case 3:                                                      \
+      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 3);                        \
+      break;                                                     \
+    case 4:                                                      \
+      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 4);                        \
+      break;                                                     \
+    default:                                                     \
+      TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
+      break;                                                     \
+  }
+
+void int8pack_mm_kernel(
+    const Tensor& C,
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& scales) {
+
+  const BFloat16* A_data = A.data_ptr<BFloat16>();
+  const int8_t* B_data = B.data_ptr<int8_t>();
+  BFloat16* C_data = C.data_ptr<BFloat16>();
+  const BFloat16* S_data = scales.data_ptr<BFloat16>();
+
+  int M = A.size(0);
+  int N = B.size(0);
+  int K = A.size(1);
+
+  constexpr int BLOCK_M = 4;
+  constexpr int BLOCK_N = 4;
+
+  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
+  const int NB = (N + BLOCK_N - 1) / BLOCK_N;
+
+  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
+    int mb{0}, nb{0};
+    data_index_init(begin, mb, MB, nb, NB);
+
+    for (const auto i : c10::irange(begin, end)) {
+      (void)i;
+
+      int mb_start = mb * BLOCK_M;
+      int mb_size = std::min(BLOCK_M, M - mb_start);
+      int nb_start = nb * BLOCK_N;
+      int nb_size = std::min(BLOCK_N, N - nb_start);
+
+      const BFloat16* A_ptr = A_data + mb_start * K;
+      const int8_t* B_ptr = B_data + nb_start * K;
+      const BFloat16* S_ptr = S_data + nb_start;
+      BFloat16* C_ptr = C_data + mb_start * N + nb_start;
+
+      switch (mb_size) {
+        case 1:
+          LAUNCH_TINYGEMM_NB_SIZE(1);
+          break;
+        case 2:
+          LAUNCH_TINYGEMM_NB_SIZE(2);
+          break;
+        case 3:
+          LAUNCH_TINYGEMM_NB_SIZE(3);
+          break;
+        case 4:
+          LAUNCH_TINYGEMM_NB_SIZE(4);
+          break;
+        default:
+          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
+      }
+
+      // move to the next index
+      data_index_step(mb, MB, nb, NB);
+    }
+  });
+}
+
+} // anonymous namespace
+
+ALSO_REGISTER_AVX512_DISPATCH(int8pack_mm_stub, &int8pack_mm_kernel);
+
+} // at::native
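For orientation, and not part of the commit: int8pack_mm_kernel above tiles the M x N output into blocks of at most 4 x 4, distributes the tiles across threads with at::parallel_for, and lets each tinygemm_kernel instantiation reduce over the full K dimension so the accumulators stay in registers. A simplified scalar sketch of that driver structure (hypothetical free function, no vectorization, no prefetching, no dispatch macros) is:

  #include <algorithm>  // std::min
  #include <cstdint>    // int8_t

  void int8pack_mm_sketch(float* C, const float* A, const int8_t* B,
                          const float* scales, int M, int N, int K) {
    constexpr int BLOCK_M = 4;
    constexpr int BLOCK_N = 4;
    // The outer two loops walk the 4x4 output tiles; the real kernel flattens
    // them into one index range handled by at::parallel_for.
    for (int mb = 0; mb < M; mb += BLOCK_M) {
      for (int nb = 0; nb < N; nb += BLOCK_N) {
        for (int m = mb; m < std::min(mb + BLOCK_M, M); ++m) {
          for (int n = nb; n < std::min(nb + BLOCK_N, N); ++n) {
            float acc = 0.f;
            // K-reduction for one output element; the AVX512/AVX2 paths do this
            // 16 or 8 elements at a time and fuse in the per-channel scale.
            for (int k = 0; k < K; ++k) {
              acc += A[m * K + k] * (static_cast<float>(B[n * K + k]) * scales[n]);
            }
            C[m * N + n] = acc;
          }
        }
      }
    }
  }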

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 0 deletions
@@ -4089,6 +4089,10 @@
     CPU: _weight_int4pack_mm_cpu
     CUDA: _weight_int4pack_mm_cuda
 
+- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
+  dispatch:
+    CPU: _weight_int8pack_mm_cpu
+
 - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
   python_module: sparse
 
test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ aten::_values
603603
aten::_values_copy
604604
aten::_values_copy.out
605605
aten::_weight_int4pack_mm
606+
aten::_weight_int8pack_mm
606607
aten::_weight_norm_interface_backward
607608
aten::_weight_norm_interface_backward.out
608609
aten::adaptive_avg_pool2d.out

0 commit comments