Skip to content

Commit 6ac62ab

Browse files
[Matrix][SYCL] Add bfloat16 support for joint_matrix (#6113)
1 parent ba5a126 commit 6ac62ab

File tree

2 files changed

+342
-0
lines changed

2 files changed

+342
-0
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <CL/__spirv/spirv_ops.hpp>
1212
#include <CL/sycl/detail/defines_elementary.hpp>
1313
#include <CL/sycl/feature_test.hpp>
14+
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
1415

1516
__SYCL_INLINE_NAMESPACE(cl) {
1617
namespace sycl {
@@ -453,6 +454,156 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
453454
#undef OP
454455
};
455456

457+
template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
458+
class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
459+
Layout, Group> {
460+
joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
461+
Layout, Group> &M;
462+
std::size_t idx;
463+
464+
public:
465+
wi_element(joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows,
466+
NumCols, Layout, Group> &Mat,
467+
std::size_t i)
468+
: M(Mat), idx(i) {}
469+
operator sycl::ext::oneapi::experimental::bfloat16() {
470+
#ifdef __SYCL_DEVICE_ONLY__
471+
return __spirv_VectorExtractDynamic(M.spvm, idx);
472+
#else
473+
throw runtime_error("joint matrix is not supported on host device.",
474+
PI_INVALID_DEVICE);
475+
#endif // __SYCL_DEVICE_ONLY__
476+
}
477+
478+
explicit operator bool() {
479+
#ifdef __SYCL_DEVICE_ONLY__
480+
return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
481+
M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
482+
#else
483+
throw runtime_error("joint matrix is not supported on host device.",
484+
PI_INVALID_DEVICE);
485+
#endif // __SYCL_DEVICE_ONLY__
486+
}
487+
488+
wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) {
489+
#ifdef __SYCL_DEVICE_ONLY__
490+
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
491+
return *this;
492+
#else
493+
(void)rhs;
494+
throw runtime_error("joint matrix is not supported on host device.",
495+
PI_INVALID_DEVICE);
496+
#endif // __SYCL_DEVICE_ONLY__
497+
}
498+
499+
wi_element &
500+
operator=(const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows,
501+
NumCols, Layout, Group> &rhs) {
502+
#ifdef __SYCL_DEVICE_ONLY__
503+
M.spvm = __spirv_VectorInsertDynamic(
504+
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
505+
return *this;
506+
#else
507+
(void)rhs;
508+
throw runtime_error("joint matrix is not supported on host device.",
509+
PI_INVALID_DEVICE);
510+
#endif // __SYCL_DEVICE_ONLY__
511+
}
512+
513+
#if __SYCL_DEVICE_ONLY__
514+
#define OP(opassign, op) \
515+
wi_element &operator opassign( \
516+
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
517+
M.spvm = __spirv_VectorInsertDynamic( \
518+
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
519+
return *this; \
520+
}
521+
#else // __SYCL_DEVICE_ONLY__
522+
#define OP(opassign, op) \
523+
wi_element &operator opassign( \
524+
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
525+
(void)rhs; \
526+
throw runtime_error("joint matrix is not supported on host device.", \
527+
PI_INVALID_DEVICE); \
528+
}
529+
#endif // __SYCL_DEVICE_ONLY__
530+
OP(+=, +)
531+
OP(-=, -)
532+
OP(*=, *)
533+
OP(/=, /)
534+
#undef OP
535+
536+
#if __SYCL_DEVICE_ONLY__
537+
#define OP(type, op) \
538+
friend type operator op( \
539+
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
540+
NumCols, Layout, Group> &lhs, \
541+
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
542+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
543+
} \
544+
friend type operator op( \
545+
const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
546+
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
547+
NumCols, Layout, Group> &rhs) { \
548+
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
549+
}
550+
OP(sycl::ext::oneapi::experimental::bfloat16, +)
551+
OP(sycl::ext::oneapi::experimental::bfloat16, -)
552+
OP(sycl::ext::oneapi::experimental::bfloat16, *)
553+
OP(sycl::ext::oneapi::experimental::bfloat16, /)
554+
#undef OP
555+
#define OP(type, op) \
556+
friend type operator op( \
557+
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
558+
NumCols, Layout, Group> &lhs, \
559+
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
560+
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
561+
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
562+
} \
563+
friend type operator op( \
564+
const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
565+
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
566+
NumCols, Layout, Group> &rhs) { \
567+
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
568+
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
569+
}
570+
OP(bool, ==)
571+
OP(bool, !=)
572+
OP(bool, <)
573+
OP(bool, >)
574+
OP(bool, <=)
575+
OP(bool, >=)
576+
#undef OP
577+
#else // __SYCL_DEVICE_ONLY__
578+
#define OP(type, op) \
579+
friend type operator op( \
580+
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
581+
NumCols, Layout, Group> &, \
582+
const sycl::ext::oneapi::experimental::bfloat16 &) { \
583+
throw runtime_error("joint matrix is not supported on host device.", \
584+
PI_INVALID_DEVICE); \
585+
} \
586+
friend type operator op( \
587+
const sycl::ext::oneapi::experimental::bfloat16 &, \
588+
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
589+
NumCols, Layout, Group> &) { \
590+
throw runtime_error("joint matrix is not supported on host device.", \
591+
PI_INVALID_DEVICE); \
592+
}
593+
OP(sycl::ext::oneapi::experimental::bfloat16, +)
594+
OP(sycl::ext::oneapi::experimental::bfloat16, -)
595+
OP(sycl::ext::oneapi::experimental::bfloat16, *)
596+
OP(sycl::ext::oneapi::experimental::bfloat16, /)
597+
OP(bool, ==)
598+
OP(bool, !=)
599+
OP(bool, <)
600+
OP(bool, >)
601+
OP(bool, <=)
602+
OP(bool, >=)
603+
#undef OP
604+
#endif // __SYCL_DEVICE_ONLY__
605+
};
606+
456607
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
457608
typename Group>
458609
class wi_data {
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
2+
#include <CL/sycl.hpp>
3+
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
4+
#include <iostream>
5+
6+
using namespace sycl::ext::oneapi::experimental::matrix;
7+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
8+
9+
static constexpr auto TILE_SZ = 16;
10+
static constexpr auto TM = TILE_SZ - 1;
11+
static constexpr auto TN = TILE_SZ - 1;
12+
static constexpr auto TK = 2 * TILE_SZ - 2;
13+
14+
static constexpr auto SG_SZ = 16;
15+
16+
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
17+
public:
18+
T *mat;
19+
20+
public:
21+
T *get_data() { return mat; }
22+
void set_data(T *data) { mat = data; }
23+
big_matrix(T *data) : mat(data) {}
24+
};
25+
26+
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
27+
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
28+
size_t NUM_COLS_C>
29+
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
30+
big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
31+
big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
32+
size_t M = NUM_ROWS_C;
33+
size_t N = NUM_COLS_C;
34+
size_t K = NUM_COLS_A;
35+
// B => K/4 x N*4, A => M x K, C => M, N
36+
// stride should be X's cols, e.g., B's stirde = N*4
37+
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
38+
size_t NDRangeM = M / TM;
39+
size_t NDRangeN = N / TN;
40+
sycl::buffer<bfloat16, 2> bufA(A.get_data(), sycl::range<2>(M, K));
41+
sycl::buffer<bfloat16, 2> bufB(B.get_data(), sycl::range<2>(K, N));
42+
sycl::buffer<float, 2> bufC((float *)C.get_data(), sycl::range<2>(M, N));
43+
44+
sycl::queue q;
45+
q.submit([&](sycl::handler &cgh) {
46+
auto accC = bufC.get_access<sycl::access::mode::read_write>(cgh);
47+
auto accA = bufA.get_access<sycl::access::mode::read_write>(cgh);
48+
auto accB = bufB.get_access<sycl::access::mode::read_write>(cgh);
49+
50+
cgh.parallel_for<class imatrix>(
51+
sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
52+
[accA, accB, accC, M, N, K](sycl::nd_item<2> spmd_item)
53+
54+
{
55+
// The submatrix API has to be accessed by all the workitems in a
56+
// subgroup these functions will be called once by the subgroup no
57+
// code divergence between the workitems
58+
const auto global_idx = spmd_item.get_global_id(0);
59+
const auto global_idy = spmd_item.get_global_id(1);
60+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
61+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
62+
63+
sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group();
64+
joint_matrix<bfloat16, TM, TK> sub_a(sg);
65+
// For B, since current implementation does not support non-packed
66+
// layout, users need to specify the updated VNNI sizes along with
67+
// the packed_b layout. By default, the layout is row_major and size
68+
// is (TK, TN).
69+
joint_matrix<bfloat16, TK, TN, matrix_layout::packed_b> sub_b(sg);
70+
joint_matrix<float, TM, TN> sub_c(sg);
71+
72+
// AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
73+
// strideX = X's cols, so strideC = N, strideA = K, strideB = N*4
74+
joint_matrix_load(sg, sub_c,
75+
accC.get_pointer() + (sg_startx * TM) * N +
76+
sg_starty / SG_SZ * TN,
77+
N, matrix_layout::row_major);
78+
for (int k = 0; k < K / TK; k += 1) { //
79+
joint_matrix_load(
80+
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
81+
K, matrix_layout::row_major);
82+
// Assuming B data is already in VNNI format.
83+
joint_matrix_load(sg, sub_b,
84+
accB.get_pointer() + (k * TK / 2) * (N * 2) +
85+
sg_starty / SG_SZ * TN * 2,
86+
N * 2, matrix_layout::packed_b);
87+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
88+
}
89+
joint_matrix_store(sg, sub_c,
90+
accC.get_pointer() + (sg_startx * TM) * N +
91+
sg_starty / SG_SZ * TN,
92+
N, matrix_layout::row_major);
93+
}); // parallel for
94+
}).wait();
95+
}
96+
97+
static constexpr size_t MATRIX_M = TM * 2;
98+
static constexpr size_t MATRIX_N = TN * 2;
99+
static constexpr size_t MATRIX_K = TK * 2;
100+
bfloat16 A[MATRIX_M][MATRIX_K];
101+
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
102+
unsigned short Aref[MATRIX_M][MATRIX_K];
103+
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
104+
float C[MATRIX_M][MATRIX_N];
105+
float D[MATRIX_M][MATRIX_N];
106+
107+
float make_fp32(short x) {
108+
unsigned int y = x;
109+
y = y << 16;
110+
float *res = reinterpret_cast<float *>(&y);
111+
return *res;
112+
}
113+
114+
unsigned short make_bf16(float x) {
115+
int *res = reinterpret_cast<int *>(&x);
116+
*res = *res >> 16;
117+
return (unsigned short)*res;
118+
}
119+
120+
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
121+
int K) {
122+
// tiling
123+
for (int m = 0; m < M; m++)
124+
for (int n = 0; n < N; n++) {
125+
for (int k = 0; k < K; k++) {
126+
short *va = (short *)(A_mem + m * K + k);
127+
short *vb = (short *)(B_mem + k * N + n);
128+
float acc = *((float *)(C_mem + m * N + n));
129+
// FIXME: Should we do reduce-add in another version?
130+
for (int i = 0; i < 2; i++) {
131+
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
132+
}
133+
*((float *)(C_mem + m * N + n)) = acc;
134+
}
135+
}
136+
}
137+
138+
int main() {
139+
for (int i = 0; i < MATRIX_M; i++) {
140+
for (int j = 0; j < MATRIX_K; j++) {
141+
// Ee create bfloat16 from unsigned short since float-to-bfloat's
142+
// conversion is not allowed.
143+
A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
144+
Aref[i][j] = make_bf16(1.0f * (i + j));
145+
}
146+
}
147+
for (int i = 0; i < MATRIX_K / 2; i++) {
148+
for (int j = 0; j < MATRIX_N * 2; j++) {
149+
B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
150+
Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
151+
}
152+
}
153+
for (int i = 0; i < MATRIX_M; i++) {
154+
for (int j = 0; j < MATRIX_N; j++) {
155+
C[i][j] = 1.0;
156+
D[i][j] = 1.0;
157+
}
158+
}
159+
160+
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
161+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
162+
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
163+
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
164+
matrix_multiply(MC, MA, MB);
165+
matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
166+
MATRIX_N, MATRIX_K / 2);
167+
168+
bool res = true;
169+
for (int i = 0; i < MATRIX_M; i++) {
170+
for (int j = 0; j < MATRIX_N; j++) {
171+
if (C[i][j] != D[i][j])
172+
res = false;
173+
}
174+
}
175+
if (res)
176+
std::cout << "passed\n";
177+
else
178+
std::cout << "failed\n";
179+
for (int i = 0; i < MATRIX_M; i++) {
180+
for (int j = 0; j < MATRIX_N; j++)
181+
std::cout << C[i][j] << ", ";
182+
std::cout << "\n";
183+
}
184+
std::cout << std::endl;
185+
for (int i = 0; i < MATRIX_M; i++) {
186+
for (int j = 0; j < MATRIX_N; j++)
187+
std::cout << D[i][j] << ", ";
188+
std::cout << "\n";
189+
}
190+
}
191+
#endif // (SYCL_EXT_ONEAPI_MATRIX == 2)

0 commit comments

Comments
 (0)