
Commit 918c453

DRAFT - Add support for different output and accumulator types in SPIR-V. Add casting of bfloat16 to float.

1 parent 7b9490b, commit 918c453

17 files changed: +403 -460 lines

sycl/include/sycl/__spirv/spirv_ops.hpp

Lines changed: 2 additions & 2 deletions
@@ -84,15 +84,15 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
     std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
     size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);
 
-template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
+template <typename TA, typename TB, typename TC, typename TD, std::size_t M, std::size_t K,
           std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
           __spv::MatrixUse UC,
           __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
           __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
           __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
           __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
 extern __DPCPP_SYCL_EXTERNAL
-    __spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
+    __spv::__spirv_CooperativeMatrixKHR<TD, S, M, N, UC> *
     __spirv_CooperativeMatrixMulAddKHR(
         __spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
         __spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
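
The new TD parameter decouples the element type of the returned matrix from the TC accumulator type. For context, a minimal standalone model of that split in plain C++ (mad_model and its flat row-major layout are illustrative assumptions, not the SPIR-V interface):

#include <cstddef>
#include <vector>

// Sketch only: TC is the accumulator element type, TD the result element type,
// mirroring the TC/TD split introduced above for __spirv_CooperativeMatrixMulAddKHR.
template <typename TA, typename TB, typename TC, typename TD>
std::vector<TD> mad_model(const std::vector<TA> &A, const std::vector<TB> &B,
                          const std::vector<TC> &C, std::size_t M,
                          std::size_t K, std::size_t N) {
  std::vector<TD> D(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n) {
      TC acc = C[m * N + n]; // accumulate in the TC type
      for (std::size_t k = 0; k < K; ++k)
        acc += static_cast<TC>(A[m * K + k]) * static_cast<TC>(B[k * N + n]);
      D[m * N + n] = static_cast<TD>(acc); // convert once into the TD result
    }
  return D;
}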

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 35 additions & 0 deletions
@@ -263,6 +263,26 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
 #endif // __SYCL_DEVICE_ONLY__
   }
 
+  operator float() {
+#ifdef __SYCL_DEVICE_ONLY__
+    sycl::ext::oneapi::bfloat16 *ExtractP =
+        __spirv_AccessChain<sycl::ext::oneapi::bfloat16,
+                            sycl::ext::oneapi::bfloat16, NumRows, NumCols,
+                            spv_matrix_use_traits<Use>::value,
+                            spv_scope_traits<Group>::value>(&M.spvm, idx);
+    union {
+      uint16_t intStorage;
+      sycl::ext::oneapi::bfloat16 floatValue;
+    };
+    floatValue = *ExtractP;
+    return __devicelib_ConvertBF16ToFINTEL(intStorage);
+
+#else
+    throw exception(make_error_code(errc::runtime),
+                    "joint matrix is not supported on host.");
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
   explicit operator bool() {
 #ifdef __SYCL_DEVICE_ONLY__
     sycl::ext::oneapi::bfloat16 *ExtractP =

@@ -295,6 +315,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
 #endif // __SYCL_DEVICE_ONLY__
   }
 
+  wi_element &operator=(const float &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    float *InsertP =
+        __spirv_AccessChain<float, float, NumRows, NumCols,
+                            spv_matrix_use_traits<Use>::value,
+                            spv_scope_traits<Group>::value>(&M.spvm, idx);
+    *InsertP = rhs;
+    return *this;
+#else
+    (void)rhs;
+    throw exception(make_error_code(errc::runtime),
+                    "joint matrix is not supported on host.");
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
   wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
                                          NumCols, Use, Layout, Group> &rhs) {
 #ifdef __SYCL_DEVICE_ONLY__
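
On the device, the union plus __devicelib_ConvertBF16ToFINTEL above reinterprets the bfloat16 storage bits and widens them to float. For context, a minimal host-side sketch of the same widening in plain C++ (bf16_bits_to_float is a hypothetical helper, not part of this patch), relying on bfloat16 being the upper 16 bits of an IEEE-754 binary32 value:

#include <cstdint>
#include <cstring>

// Widen raw bfloat16 bits to float: place them in the upper half of a 32-bit
// pattern, zero-fill the low mantissa bits, and bit-cast. This direction of
// the conversion is exact; no rounding is involved.
inline float bf16_bits_to_float(std::uint16_t bits) {
  std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &wide, sizeof(result));
  return result;
}

// Example: 0x3F80 is bfloat16 1.0, so bf16_bits_to_float(0x3F80) == 1.0f.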

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

Lines changed: 8 additions & 3 deletions
@@ -85,13 +85,18 @@ extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv(
   }
 }
 
-template<typename Ta, typename Tb, typename Tc>
+template<typename Ta, typename Tb, typename Tc, typename Td>
 constexpr uint32_t CalculateMatrixOperand() {
   if constexpr (std::is_same<Ta, sycl::ext::oneapi::bfloat16>::value &&
-                std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value &&
-                std::is_same<Tc, float>::value)
+                std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value)
     return static_cast<uint32_t>(
         __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);
+  if constexpr (std::is_same<Tc, sycl::ext::oneapi::bfloat16>::value)
+    return static_cast<uint32_t>(
+        __spv::MatrixOperands::MatrixCBFloat16ComponentsINTEL);
+  if constexpr (std::is_same<Td, sycl::ext::oneapi::bfloat16>::value)
+    return static_cast<uint32_t>(
+        __spv::MatrixOperands::MatrixResultBFloat16ComponentsINTEL);
   if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
     return static_cast<uint32_t>(
         __spv::MatrixOperands::MatrixASignedComponentsKHR);
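
Note that each if constexpr branch returns immediately, so exactly one operand flag is produced and the checks apply in order: bfloat16 A/B first, then the C and D element types. A standalone model of that first-match selection in plain C++17 (bf16_tag, Flag, and pick_flag are hypothetical stand-ins for bfloat16 and the __spv::MatrixOperands constants; the signed/unsigned branches are omitted):

#include <cstdint>
#include <type_traits>

struct bf16_tag {}; // stand-in for sycl::ext::oneapi::bfloat16

// Stand-ins for the SPIR-V matrix operand flags referenced above.
enum class Flag : std::uint32_t { None = 0x0, ABBF16 = 0x1, CBF16 = 0x2, ResultBF16 = 0x4 };

template <typename Ta, typename Tb, typename Tc, typename Td>
constexpr std::uint32_t pick_flag() {
  if constexpr (std::is_same_v<Ta, bf16_tag> && std::is_same_v<Tb, bf16_tag>)
    return static_cast<std::uint32_t>(Flag::ABBF16);
  if constexpr (std::is_same_v<Tc, bf16_tag>)
    return static_cast<std::uint32_t>(Flag::CBF16);
  if constexpr (std::is_same_v<Td, bf16_tag>)
    return static_cast<std::uint32_t>(Flag::ResultBF16);
  return static_cast<std::uint32_t>(Flag::None);
}

// First match wins: bfloat16 A/B dominates even if C or D is also bfloat16.
static_assert(pick_flag<bf16_tag, bf16_tag, bf16_tag, bf16_tag>() ==
              static_cast<std::uint32_t>(Flag::ABBF16));
static_assert(pick_flag<float, float, bf16_tag, float>() ==
              static_cast<std::uint32_t>(Flag::CBF16));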

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 13 additions & 6 deletions
@@ -431,8 +431,7 @@ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
           sycl::detail::convertTypeToMatrixTypeString<Tc>(),
           sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
 #endif // defined(__SYCL_DEVICE_ONLY__)
-inline __SYCL_ALWAYS_INLINE void
-joint_matrix_mad(
+inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
     Group,
     joint_matrix<Group, Td, use::accumulator, M, N,
                  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,

@@ -462,9 +461,9 @@ joint_matrix_mad(
   }
 #else
   constexpr uint32_t MatrixOperand =
-      sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc>();
-  D.spvm =
-      __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, MatrixOperand);
+      sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc, Td>();
+  D.spvm = __spirv_CooperativeMatrixMulAddKHR<Ta, Tb, Tc, Td>(
+      A.spvm, B.spvm, C.spvm, MatrixOperand);
 #endif // defined(__NVPTX__)
 #else
   std::ignore = A;

@@ -489,10 +488,18 @@ void joint_matrix_copy(
   using storage_element_type =
       typename oneapi::detail::jm_type_interpretation_helper_trait<
           T2>::storage_element_type;
+  using src_storage_element_type =
+      typename oneapi::detail::jm_type_interpretation_helper_trait<
+          T1>::storage_element_type;
+
   auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
   auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
   for (int i = 0; i < wi_data_c.length(); i++) {
-    wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
+    if constexpr (std::is_same_v<T1, half>) {
+      wi_data_dst[i] = static_cast<storage_element_type>(static_cast<src_storage_element_type>(wi_data_c[i]));
+    } else {
+      wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
+    }
   }
 #endif // defined(__NVPTX__)
 #else
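
The explicit <Ta, Tb, Tc, Td> arguments are needed because Td appears only in the intrinsic's return type, so it cannot be deduced from A.spvm, B.spvm, and C.spvm. A minimal standalone illustration of that deduction rule in plain C++ (make_d is a hypothetical function, not the intrinsic):

// A template parameter used only in the return type is never deduced from the
// call arguments; it must be supplied explicitly.
template <typename TC, typename TD>
TD make_d(TC c) {
  return static_cast<TD>(c);
}

int main() {
  float c = 1.5f;
  // auto d = make_d(c);          // error: TD cannot be deduced
  auto d = make_d<float, int>(c); // OK: spelled out, as in
                                  // __spirv_CooperativeMatrixMulAddKHR<Ta, Tb, Tc, Td>(...)
  return d == 1 ? 0 : 1;
}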

sycl/test-e2e/Matrix/Inputs/common.hpp

Lines changed: 6 additions & 1 deletion
@@ -67,7 +67,7 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
     for (unsigned int n = 0; n < N; n++) {
       int c_ind = transpose_c ? (n * M + m) : m * N + n;
       Tc acc = *(C + c_ind);
-
+      float tmp = 0.f;
       for (unsigned int k = 0; k < K; k++) {
         int a_ind = colmajor_a ? (k * M + m) : m * K + k;
         int b_ind = colmajor_b ? (n * K + k) : k * N + n;

@@ -80,6 +80,8 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
           acc += make_fp32(va[i]) * make_fp32(vb[i]);
         else if constexpr (std::is_same_v<Ta, sycl::half>)
           acc += (float)va[i] * (float)vb[i];
+        else if constexpr (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, bfloat16>)
+          tmp += (float)va[i] * (float)vb[i];
         else if constexpr (std::is_same_v<Ta, float> &&
                                std::is_same_v<Tc, float> ||
                            std::is_integral_v<Ta> && std::is_integral_v<Tc> ||

@@ -92,6 +94,8 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
           assert(false && "Unsupported type in matrix_multiply_ref.");
         }
       }
+      if constexpr (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, bfloat16>)
+        acc += (bfloat16)tmp;
 
       if constexpr (!std::is_same_v<F, std::nullptr_t>) {
         lambda(acc);

@@ -184,6 +188,7 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
     for (int j = 0; j < cols; j++) {
       if constexpr (!exact && (std::is_same_v<T1, float> ||
                                std::is_same_v<T1, bfloat16> ||
+                               std::is_same_v<T1, half> ||
                                (std::is_same_v<T1, double> &&
                                 std::is_same_v<T2, double>))) {
         float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
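
For the bfloat16-accumulator case, the reference now sums the K partial products in a float temporary and rounds into the bfloat16 accumulator once, rather than rounding after every product. A standalone sketch of why that matters (plain C++; round_to_bf16 is a hypothetical truncation-based stand-in for real bfloat16 rounding, so the exact numbers are only indicative):

#include <cstdint>
#include <cstring>
#include <iostream>

// Emulate bfloat16 storage by truncating a float to its upper 16 bits
// (real bfloat16 rounds to nearest-even; this is only a rough model).
inline float round_to_bf16(float v) {
  std::uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&v, &bits, sizeof(v));
  return v;
}

int main() {
  const int K = 64;
  const float prod = 1.0078125f; // 1 + 2^-7, exactly representable in bfloat16
  float per_step = 0.0f;         // rounded into "bfloat16" after every product
  float tmp = 0.0f;              // accumulated in float, rounded once at the end
  for (int k = 0; k < K; ++k) {
    per_step = round_to_bf16(per_step + prod);
    tmp += prod;
  }
  // Rounding once keeps the exact sum 64.5 (as the updated reference does via
  // the float tmp); rounding every step drifts below it.
  std::cout << "per-step rounding: " << per_step << '\n'
            << "single rounding:   " << round_to_bf16(tmp) << '\n';
  return 0;
}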

joint_matrix_16bit_impl.hpp (new file)

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+//===---joint_matrix_16bit_impl.hpp - DPC++ joint_matrix----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
+          size_t TK, layout B_layout>
+class imatrix;
+
+template <typename Tab, typename TAcc, typename TResult, size_t M, size_t N,
+          size_t K, size_t TM, size_t TN, size_t TK, layout B_layout, size_t VF>
+void matrix_multiply(big_matrix<TResult, M, N> &D, big_matrix<TAcc, M, N> &C,
+                     big_matrix<Tab, M, K> &A, big_matrix<Tab, K / VF, N * VF> &B) {
+  size_t NDRangeM = M / TM;
+  size_t NDRangeN = N / TN;
+  buffer<Tab, 2> bufA(A.get_data(), range<2>(M, K));
+  buffer<Tab, 2> bufB(B.get_data(), range<2>(K, N));
+  buffer<TAcc, 2> bufC((TAcc *)C.get_data(), range<2>(M, N));
+  buffer<TResult, 2> bufD((TResult *)D.get_data(), range<2>(M, N));
+  queue q;
+  size_t sg_size = get_sg_size<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(q);
+
+  q.submit([&](handler &cgh) {
+     accessor accA{bufA, cgh};
+     accessor accB{bufB, cgh};
+     accessor accC{bufC, cgh};
+     accessor accD{bufD, cgh};
+
+     cgh.parallel_for<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(
+         nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
+         [=](nd_item<2> spmd_item)
+#ifdef SG_SZ
+             [[sycl::reqd_sub_group_size(SG_SZ)]]
+#endif
+         {
+           // The submatrix API has to be accessed by all the work-items in a
+           // subgroup; these functions will be called once by the subgroup,
+           // with no code divergence between the work-items.
+           const auto global_idx = spmd_item.get_global_id(0);
+           const auto global_idy = spmd_item.get_global_id(1);
+           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
+           const auto sg_starty = global_idy - spmd_item.get_local_id(1);
+
+           sub_group sg = spmd_item.get_sub_group();
+           joint_matrix<sub_group, Tab, use::a, TM, TK, layout::row_major> sub_a;
+           joint_matrix<sub_group, Tab, use::b, TK, TN, B_layout> sub_b;
+           joint_matrix<sub_group, TAcc, use::accumulator, TM, TN> sub_c;
+           joint_matrix<sub_group, TResult, use::accumulator, TM, TN> sub_d;
+
+           joint_matrix_load(
+               sg, sub_c,
+               accC.template get_multi_ptr<access::decorated::no>() +
+                   (sg_startx * TM) * N + sg_starty / sg_size * TN,
+               N, layout::row_major);
+
+           for (int k = 0; k < K / TK; k += 1) {
+             joint_matrix_load(
+                 sg, sub_a,
+                 accA.template get_multi_ptr<access::decorated::no>() +
+                     (sg_startx * TM) * K + k * TK,
+                 K);
+             joint_matrix_load(
+                 sg, sub_b,
+                 accB.template get_multi_ptr<access::decorated::no>() +
+                     (k * TK / VF) * (N * VF) + sg_starty / sg_size * TN * VF,
+                 N * VF);
+
+             joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
+             joint_matrix_copy(sg, sub_d, sub_c);
+           }
+
+           joint_matrix_store(
+               sg, sub_d,
+               accD.template get_multi_ptr<access::decorated::no>() +
+                   (sg_startx * TM) * N + sg_starty / sg_size * TN,
+               N, layout::row_major);
+         }); // parallel for
+   }).wait();
+}
+
+template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
+          size_t TK, layout B_layout, size_t VF>
+void test() {
+  std::cout << "Testing: " << TM << " x " << TN << " x " << TK
+            << " [TM x TN x TK]" << std::endl;
+
+  static constexpr size_t MATRIX_M = TM * 2;
+  static constexpr size_t MATRIX_N = TN * 2;
+  static constexpr size_t MATRIX_K = TK * 2;
+  Tab A[MATRIX_M][MATRIX_K];
+  Tab B[MATRIX_K / VF][MATRIX_N * VF];
+  TAcc C[MATRIX_M][MATRIX_N];
+  TResult D[MATRIX_M][MATRIX_N];
+  TResult DRef[MATRIX_M][MATRIX_N];
+
+  matrix_rand<Tab>(MATRIX_M, MATRIX_K, (Tab *)A, Tab(1));
+  matrix_rand<Tab>(MATRIX_K / VF, MATRIX_N * VF, (Tab *)B, Tab(1));
+
+  matrix_fill(MATRIX_M, MATRIX_N, (TAcc *)C, TAcc(1));
+  matrix_fill(MATRIX_M, MATRIX_N, (TResult *)D, TResult(1));
+  matrix_fill(MATRIX_M, MATRIX_N, (TResult *)DRef, TResult(1));
+
+  big_matrix<TAcc, MATRIX_M, MATRIX_N> MC((TAcc *)&C);
+  big_matrix<TResult, MATRIX_M, MATRIX_N> MD((TResult *)&D);
+  big_matrix<Tab, MATRIX_M, MATRIX_K> MA((Tab *)&A);
+  big_matrix<Tab, MATRIX_K / VF, MATRIX_N * VF> MB((Tab *)&B);
+
+  matrix_multiply<Tab, TAcc, TResult, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK, B_layout, VF>(
+      MD, MC, MA, MB);
+  matrix_multiply_ref<Tab, Tab, TResult, VF>((Tab *)A, (Tab *)B, (TResult *)DRef, MATRIX_M,
+                                             MATRIX_N, MATRIX_K / VF);
+  assert(matrix_compare(MATRIX_M, MATRIX_N, (TResult *)D, (TResult *)DRef));
+}
+
+template<typename T1, typename T2, size_t TM, size_t TN, size_t TK, layout B_layout, size_t VF> void test_combo() {
+  test<T1, T1, T2, TM, TN, TK, B_layout, VF>();
+  test<T1, T2, T1, TM, TN, TK, B_layout, VF>();
+  test<T1, T1, T1, TM, TN, TK, B_layout, VF>();
+  test<T1, T2, T2, TM, TN, TK, B_layout, VF>();
+}
+
+
+template <typename T1, typename T2, layout B_layout, size_t VF> void test_all() {
+  test_combo<T1, T2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16, B_layout, VF>();
+  test_combo<T1, T2, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16, B_layout, VF>();
+  test_combo<T1, T2, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16, B_layout, VF>();
+  test_combo<T1, T2, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32, B_layout, VF>();
+  test_combo<T1, T2, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16, B_layout, VF>();
+  test_combo<T1, T2, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32, B_layout, VF>();
+}
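
For reference, a driver for this header would follow the existing Matrix e2e test pattern: include common.hpp and this implementation file, then instantiate test_all for the 16-bit input type and accumulator/result mix under test. A hypothetical sketch (the lit REQUIRES/RUN lines, layout, and VF value are assumptions, not part of this commit):

// Hypothetical driver sketch; file placement and the chosen template
// arguments are assumptions.
#include "common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#include "joint_matrix_16bit_impl.hpp"

int main() {
  // bfloat16 inputs with float/bfloat16 accumulator-result combinations;
  // VF = 2 assumes a packed (VNNI) B layout.
  test_all<bfloat16, float, layout::ext_intel_packed, 2>();
  // half inputs could be exercised the same way:
  // test_all<half, float, layout::ext_intel_packed, 2>();
  return 0;
}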
