Skip to content

Commit ef015f4

Browse files
[Matrix] Enable joint_matrix_fill for joint_matrix feature
1 parent 3205368 commit ef015f4

File tree

3 files changed

+193
-0
lines changed

3 files changed

+193
-0
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ __spirv_JointMatrixSUMadINTEL(
8686
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
8787
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
8888

89+
// Declaration of the SPIR-V built-in used to lower joint_matrix_fill: it
// returns a joint matrix handle whose elements are initialized from the
// scalar `v` (see JointMatrixFillINTEL in the SPV_INTEL_joint_matrix spec).
//   T    - element type
//   R, C - rows / columns of the matrix
//   L    - matrix layout, defaults to RowMajor
//   S    - synchronization scope, defaults to Subgroup
// `Sc` is the runtime scope operand and defaults to the template scope S.
template <typename T, std::size_t R, std::size_t C,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
__spirv_JointMatrixFillINTEL(const T &v, __spv::Scope::Flag Sc = S);
94+
8995
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
9096
#error \
9197
"SPIR-V built-ins are not available. Please set -fdeclare-spirv-builtins flag."

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,24 @@ joint_matrix_mad(Group sg, joint_matrix<T1, M, K, LayoutA, Group> &mA,
191191
PI_INVALID_DEVICE);
192192
#endif // __SYCL_DEVICE_ONLY__
193193
}
194+
195+
template <typename Group, typename T, size_t NumRows, size_t NumCols,
196+
matrix_layout Layout>
197+
inline __SYCL_ALWAYS_INLINE void
198+
joint_matrix_fill(Group sg,
199+
joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
200+
const T &v) {
201+
#ifdef __SYCL_DEVICE_ONLY__
202+
res.spvm =
203+
__spirv_JointMatrixFillINTEL<T, NumRows, NumCols,
204+
spv_matrix_layout_traits<Layout>::value>(
205+
v, spv_scope_traits<Group>::value);
206+
#else
207+
(void)res;
208+
(void)v;
209+
#endif // __SYCL_DEVICE_ONLY__
210+
}
211+
194212
} // namespace experimental::matrix
195213
} // namespace oneapi
196214
} // namespace ext
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
2+
#include <CL/sycl.hpp>
3+
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
4+
#include <iostream>
5+
6+
using namespace sycl;
7+
using namespace sycl::ext::oneapi::experimental::matrix;
8+
9+
#define TILE_SZ 16
10+
#define TM (TILE_SZ-4)
11+
#define TN (TILE_SZ-4)
12+
#define TK (4 * TILE_SZ-16)
13+
14+
#define SG_SZ 16
15+
16+
// Non-owning view of a NUM_ROWS x NUM_COLS matrix laid out contiguously.
// The wrapped buffer must outlive the wrapper; no copy is ever made.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

public:
  // Wrap an existing buffer.
  big_matrix(T *data) : mat(data) {}
  // Access the underlying storage pointer.
  T *get_data() { return mat; }
  // Re-point the wrapper at a different buffer.
  void set_data(T *data) { mat = data; }
};
26+
27+
// Device-side tiled int8 GEMM using the joint_matrix extension.
// C (M x N, int32) += A (M x K, int8) * B (packed VNNI int8).
// Each sub-group owns one TM x TN tile of C and iterates over the K
// dimension in TK-sized steps.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A, size_t NUM_ROWS_B,
          size_t NUM_COLS_B, size_t NUM_ROWS_C, size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C, big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A, big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  // B => K/4 x N*4, A => M x K, C => M, N
  // stride should be X's cols, e.g., B's stride = N*4
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);

    // One work-group row per TM tile; SG_SZ work-items along the second
    // dimension form the sub-group that cooperates on each tile.
    cgh.parallel_for<class imatrix>(
        nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
        [accA, accB, accC, M, N, K](nd_item<2> spmd_item)

        {
          // The submatrix API has to be accessed by all the workitems in a
          // subgroup these functions will be called once by the subgroup no
          // code divergence between the workitems
          const auto global_idx = spmd_item.get_global_id(0);
          const auto global_idy = spmd_item.get_global_id(1);
          const auto sg_startx = global_idx - spmd_item.get_local_id(0);
          const auto sg_starty = global_idy - spmd_item.get_local_id(1);

          ext::oneapi::sub_group sg = spmd_item.get_sub_group();
          joint_matrix<int8_t, TM, TK> sub_a(sg);
          // For B, since current implementation does not support non-packed layout,
          // users need to specify the updated VNNI sizes along with the packed_b layout.
          // By default, the layout is row_major and size is (TK, TN).
          joint_matrix<int8_t, TK, TN, matrix_layout::packed_b> sub_b(sg);
          joint_matrix<int32_t, TM, TN> sub_c(sg);

          // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
          // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4
          // NOTE(review): this fill is immediately overwritten by the load
          // below, so it only exercises the new joint_matrix_fill API; the
          // accumulated result still starts from the values in accC.
          joint_matrix_fill(sg, sub_c, 0);
          joint_matrix_load(sg, sub_c,
                            accC.get_pointer() + (sg_startx * TM) * N +
                                sg_starty / SG_SZ * TN,
                            N, matrix_layout::row_major);
          for (int k = 0; k < K / TK; k += 1) {
            joint_matrix_load(
                sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                K, matrix_layout::row_major);
            // Assuming B data is already in VNNI format.
            joint_matrix_load(sg, sub_b,
                              accB.get_pointer() + (k * TK / 4) * (N * 4) +
                                  sg_starty / SG_SZ * TN * 4,
                              N * 4, matrix_layout::packed_b);
            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          }
          // Write the finished TM x TN tile back to C.
          joint_matrix_store(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
        }); // parallel for
  }).wait();
}
94+
95+
// Host-side test matrices: C is produced by the device kernel, D by the
// scalar reference implementation; main() compares the two.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];
// B is stored pre-packed in VNNI layout: (K/4) rows of (N*4) int8 values.
int8_t B[MATRIX_K / 4][MATRIX_N * 4];
int32_t C[MATRIX_M][MATRIX_N];
int32_t D[MATRIX_M][MATRIX_N];
102+
103+
// Scalar reference for the VNNI-packed int8 GEMM.
// A_mem/B_mem hold int8 data viewed as int32 words: each int32 packs four
// consecutive int8 values, and each dot-product step multiplies the four
// int8 lanes of an A word with the four lanes of a B word.
//   A_mem - M x K int32 words (M x 4K int8)
//   B_mem - K x N int32 words (VNNI-packed B)
//   C_mem - M x N int32 accumulators, updated in place
void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  // tiling
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        // Use int8_t explicitly instead of plain char: char's signedness is
        // implementation-defined, and an unsigned char here would produce
        // different products than the device's signed int8 math.
        int8_t *va = (int8_t *)(A_mem + m * K + k);
        int8_t *vb = (int8_t *)(B_mem + k * N + n);
        int acc = *(C_mem + m * N + n);
        for (int i = 0; i < 4; i++) {
          acc += (va[i] * vb[i]);
        }
        *(C_mem + m * N + n) = acc;
      }
    }
}
119+
120+
int main() {
  // Deterministic input patterns so the device result and the scalar
  // reference can be compared exactly.
  for (int i = 0; i < MATRIX_M; i++)
    for (int j = 0; j < MATRIX_K; j++)
      A[i][j] = i + 2 * j;
  for (int i = 0; i < MATRIX_K / 4; i++)
    for (int j = 0; j < MATRIX_N * 4; j++)
      B[i][j] = i + j;
  for (int i = 0; i < MATRIX_M; i++)
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1;
      D[i][j] = 1;
    }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC((int32_t *)&C);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD((int32_t *)&D);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);
  big_matrix<int8_t, MATRIX_K / 4, MATRIX_N * 4> MB((int8_t *)&B);
  matrix_multiply(MC, MA, MB);
  // The reference walks packed int8 data as int32 words, hence K/4.
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 4);

  // Compare device output against the reference, element by element.
  bool res = true;
  for (int i = 0; i < MATRIX_M; i++)
    for (int j = 0; j < MATRIX_N; j++)
      res = res && (C[i][j] == D[i][j]);
  std::cout << (res ? "passed\n" : "failed\n");

  // Dump both matrices for manual inspection on failure.
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << C[i][j] << ", ";
    std::cout << "\n";
  }
  std::cout << std::endl;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << D[i][j] << ", ";
    std::cout << "\n";
  }
}
169+
#endif // (SYCL_EXT_ONEAPI_MATRIX == 2)

0 commit comments

Comments
 (0)