[Joint Matrix] Enable different accumulator and output types in spirv. Add tests to cover bfloat16 and half floating-point sizes. #17502

Open · wants to merge 18 commits into sycl
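
For context, a minimal sketch of the capability this change enables: a joint_matrix_mad call whose accumulator (C) and result (D) fragments use different element types. The function name, tile shapes and layouts below are illustrative assumptions, not code from this patch; queue/nd_range setup, loads and stores are omitted.

// Hypothetical device-side fragment showing mixed accumulator/result types.
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

void mad_mixed_types(sycl::sub_group sg) {
  joint_matrix<sycl::sub_group, bfloat16, use::a, 8, 16, layout::row_major> A;
  joint_matrix<sycl::sub_group, bfloat16, use::b, 16, 16, layout::ext_intel_packed> B;
  joint_matrix<sycl::sub_group, bfloat16, use::accumulator, 8, 16> C;
  joint_matrix<sycl::sub_group, float, use::accumulator, 8, 16> D;
  joint_matrix_mad(sg, D, A, B, C); // D (float) = A * B + C (bfloat16)
}
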
8 changes: 4 additions & 4 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
@@ -84,15 +84,15 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
     std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
     size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);
 
-template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
-          std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
-          __spv::MatrixUse UC,
+template <typename TA, typename TB, typename TC, typename TD, std::size_t M,
+          std::size_t K, std::size_t N, __spv::MatrixUse UA,
+          __spv::MatrixUse UB, __spv::MatrixUse UC,
           __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
           __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
           __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
           __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
 extern __DPCPP_SYCL_EXTERNAL
-    __spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
+    __spv::__spirv_CooperativeMatrixKHR<TD, S, M, N, UC> *
     __spirv_CooperativeMatrixMulAddKHR(
         __spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
         __spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
34 changes: 34 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
@@ -263,6 +263,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#endif // __SYCL_DEVICE_ONLY__
}

operator float() {
#ifdef __SYCL_DEVICE_ONLY__
sycl::ext::oneapi::bfloat16 *ExtractP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
union {
uint16_t intStorage;
sycl::ext::oneapi::bfloat16 floatValue;
};
floatValue = *ExtractP;
return __spirv_ConvertBF16ToFINTEL(intStorage);
#else
throw exception(make_error_code(errc::runtime),
"joint matrix is not supported on host.");
#endif // __SYCL_DEVICE_ONLY__
}

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
sycl::ext::oneapi::bfloat16 *ExtractP =
@@ -295,6 +314,21 @@
#endif // __SYCL_DEVICE_ONLY__
}

wi_element &operator=(const float &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
float *InsertP =
__spirv_AccessChain<float, float, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
*InsertP = rhs;
return *this;
#else
(void)rhs;
throw exception(make_error_code(errc::runtime),
"joint matrix is not supported on host.");
#endif // __SYCL_DEVICE_ONLY__
}

wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
NumCols, Use, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
30 changes: 15 additions & 15 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp
@@ -85,26 +85,26 @@ extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv(
   }
 }
 
-template<typename Ta, typename Tb, typename Tc>
+template <typename Ta, typename Tb, typename Tc, typename Td>
 constexpr uint32_t CalculateMatrixOperand() {
+  uint32_t returnValue = 0x00;
   if constexpr (std::is_same<Ta, sycl::ext::oneapi::bfloat16>::value &&
-                std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value &&
-                std::is_same<Tc, float>::value)
-    return static_cast<uint32_t>(
+                std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value)
+    returnValue += static_cast<uint32_t>(
         __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);
-  if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
-    return static_cast<uint32_t>(
+  if constexpr (std::is_same<Tc, sycl::ext::oneapi::bfloat16>::value)
+    returnValue += static_cast<uint32_t>(
+        __spv::MatrixOperands::MatrixCBFloat16ComponentsINTEL);
+  if constexpr (std::is_same<Td, sycl::ext::oneapi::bfloat16>::value)
+    returnValue += static_cast<uint32_t>(
+        __spv::MatrixOperands::MatrixResultBFloat16ComponentsINTEL);
+  if constexpr (std::is_signed<Ta>::value)
+    returnValue += static_cast<uint32_t>(
         __spv::MatrixOperands::MatrixASignedComponentsKHR);
-  if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
-    return static_cast<uint32_t>(
+  if constexpr (std::is_signed<Tb>::value)
+    returnValue += static_cast<uint32_t>(
         __spv::MatrixOperands::MatrixBSignedComponentsKHR);
-  if constexpr (std::is_signed<Ta>::value && std::is_signed<Tb>::value) {
-    return static_cast<uint32_t>(
-               __spv::MatrixOperands::MatrixASignedComponentsKHR) +
-           static_cast<uint32_t>(
-               __spv::MatrixOperands::MatrixBSignedComponentsKHR);
-  }
-  return 0;
+  return returnValue;
 }
 
 } // namespace detail
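
For reference, a hedged sketch of what the reworked helper computes for one type combination. This check is not part of the patch; it assumes a translation unit that can see the internal SYCL headers and the __spv::MatrixOperands enum used above.

// Illustration only: with bfloat16 A/B, a bfloat16 accumulator C and a float
// result D, the mask combines the A/B and C bfloat16 flags and nothing else,
// since bfloat16 is a class type and never satisfies the signed checks.
constexpr uint32_t Mask = sycl::detail::CalculateMatrixOperand<
    sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
    sycl::ext::oneapi::bfloat16, float>();
static_assert(
    Mask ==
    static_cast<uint32_t>(
        __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL) +
        static_cast<uint32_t>(
            __spv::MatrixOperands::MatrixCBFloat16ComponentsINTEL));
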
24 changes: 18 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
@@ -431,8 +431,7 @@ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
       sycl::detail::convertTypeToMatrixTypeString<Tc>(),
       sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
 #endif // defined(__SYCL_DEVICE_ONLY__)
-inline __SYCL_ALWAYS_INLINE void
-joint_matrix_mad(
+inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
     Group,
     joint_matrix<Group, Td, use::accumulator, M, N,
                  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
@@ -462,9 +461,9 @@ joint_matrix_mad(
   }
 #else
   constexpr uint32_t MatrixOperand =
-      sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc>();
-  D.spvm =
-      __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, MatrixOperand);
+      sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc, Td>();
+  D.spvm = __spirv_CooperativeMatrixMulAddKHR<Ta, Tb, Tc, Td>(
+      A.spvm, B.spvm, C.spvm, MatrixOperand);
 #endif // defined(__NVPTX__)
 #else
   std::ignore = A;
@@ -489,10 +488,23 @@ void joint_matrix_copy(
   using storage_element_type =
       typename oneapi::detail::jm_type_interpretation_helper_trait<
           T2>::storage_element_type;
+  using src_storage_element_type =
+      typename oneapi::detail::jm_type_interpretation_helper_trait<
+          T1>::storage_element_type;
+
   auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
   auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
   for (int i = 0; i < wi_data_c.length(); i++) {
-    wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
+    if constexpr (std::is_same_v<T1, sycl::half>) {
+      // Special case for SRC type sycl::half since we can't
+      // cast directly from wi_element (typed half) to another type:
+      // the first cast is from wi_element to half (T1),
+      // the second cast is from half to the dst type (T2).
+      wi_data_dst[i] = static_cast<storage_element_type>(
+          static_cast<src_storage_element_type>(wi_data_c[i]));
+    } else {
+      wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
+    }
   }
 #endif // defined(__NVPTX__)
 #else
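
The new branch in joint_matrix_copy matters when the source fragment holds sycl::half elements, which previously could not be cast element-wise straight to a different destination type. A minimal sketch under assumed tile sizes (not from the patch), inside a sub-group kernel where sg is the calling sub_group:

// Hypothetical: copy a half accumulator fragment into a float fragment.
// Each element is first read back as half, then converted to float.
using namespace sycl::ext::oneapi::experimental::matrix;
joint_matrix<sycl::sub_group, sycl::half, use::accumulator, 8, 16> acc_half;
joint_matrix<sycl::sub_group, float, use::accumulator, 8, 16> acc_float;
joint_matrix_copy(sg, acc_half, acc_float);
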
17 changes: 12 additions & 5 deletions sycl/test-e2e/Matrix/Inputs/common.hpp
@@ -67,7 +67,7 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
     for (unsigned int n = 0; n < N; n++) {
       int c_ind = transpose_c ? (n * M + m) : m * N + n;
       Tc acc = *(C + c_ind);
-
+      float tmp = 0.f;
       for (unsigned int k = 0; k < K; k++) {
         int a_ind = colmajor_a ? (k * M + m) : m * K + k;
         int b_ind = colmajor_b ? (n * K + k) : k * N + n;
@@ -80,6 +80,9 @@
           acc += make_fp32(va[i]) * make_fp32(vb[i]);
         else if constexpr (std::is_same_v<Ta, sycl::half>)
           acc += (float)va[i] * (float)vb[i];
+        else if constexpr (std::is_same_v<Ta, bfloat16> &&
+                           std::is_same_v<Tc, bfloat16>)
+          tmp += (float)va[i] * (float)vb[i];
         else if constexpr (std::is_same_v<Ta, float> &&
                                std::is_same_v<Tc, float> ||
                            std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
@@ -92,6 +95,9 @@
           assert(false && "Unsupported type in matrix_multiply_ref.");
         }
       }
+      if constexpr (std::is_same_v<Ta, bfloat16> &&
+                    std::is_same_v<Tc, bfloat16>)
+        acc += (bfloat16)tmp;
 
       if constexpr (!std::is_same_v<F, std::nullptr_t>) {
         lambda(acc);
@@ -182,10 +188,11 @@ template <typename T1, typename T2, bool exact = false>
 bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
   for (int i = 0; i < rows; i++) {
     for (int j = 0; j < cols; j++) {
-      if constexpr (!exact && (std::is_same_v<T1, float> ||
-                               std::is_same_v<T1, bfloat16> ||
-                               (std::is_same_v<T1, double> &&
-                                std::is_same_v<T2, double>))) {
+      if constexpr (!exact &&
+                    (std::is_same_v<T1, float> ||
+                     std::is_same_v<T1, bfloat16> || std::is_same_v<T1, half> ||
+                     (std::is_same_v<T1, double> &&
+                      std::is_same_v<T2, double>))) {
         float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
         if (diff > FLOAT_EPSILON || std::isnan(src[i * cols + j])) {
           std::cerr << "Incorrect result in matrix. "
138 changes: 138 additions & 0 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_16bit_impl.hpp
@@ -0,0 +1,138 @@
//===---joint_matrix_16bit_impl.hpp - DPC++ joint_matrix----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
size_t TK, layout B_layout>
class imatrix;

template <typename Tab, typename TAcc, typename TResult, size_t M, size_t N,
size_t K, size_t TM, size_t TN, size_t TK, layout B_layout, size_t VF>
void matrix_multiply(big_matrix<TResult, M, N> &D, big_matrix<TAcc, M, N> &C,
big_matrix<Tab, M, K> &A,
big_matrix<Tab, K / VF, N * VF> &B) {
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;
buffer<Tab, 2> bufA(A.get_data(), range<2>(M, K));
buffer<Tab, 2> bufB(B.get_data(), range<2>(K, N));
buffer<TAcc, 2> bufC((TAcc *)C.get_data(), range<2>(M, N));
buffer<TResult, 2> bufD((TResult *)D.get_data(), range<2>(M, N));
queue q;
size_t sg_size =
get_sg_size<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(q);

q.submit([&](handler &cgh) {
accessor accA{bufA, cgh};
accessor accB{bufB, cgh};
accessor accC{bufC, cgh};
accessor accD{bufD, cgh};

cgh.parallel_for<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
[[sycl::reqd_sub_group_size(SG_SZ)]]
#endif
{
            // The submatrix API has to be accessed by all the work-items in a
            // subgroup; these functions will be called once by the subgroup,
            // with no code divergence between the work-items.
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, Tab, use::a, TM, TK, layout::row_major>
sub_a;
joint_matrix<sub_group, Tab, use::b, TK, TN, B_layout> sub_b;
joint_matrix<sub_group, TAcc, use::accumulator, TM, TN> sub_c;
joint_matrix<sub_group, TResult, use::accumulator, TM, TN> sub_d;

joint_matrix_load(
sg, sub_c,
accC.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);

for (int k = 0; k < K / TK; k += 1) {
joint_matrix_load(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * K + k * TK,
K);
joint_matrix_load(
sg, sub_b,
accB.template get_multi_ptr<access::decorated::no>() +
(k * TK / VF) * (N * VF) + sg_starty / sg_size * TN * VF,
N * VF);

joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
joint_matrix_copy(sg, sub_d, sub_c);
}

joint_matrix_store(
sg, sub_d,
accD.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
}); // parallel for
}).wait();
}

template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
size_t TK, layout B_layout, size_t VF>
void test() {
std::cout << "Testing: " << TM << " x " << TN << " x " << TK
<< " [TM x TN x TK]" << std::endl;

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
Tab A[MATRIX_M][MATRIX_K];
Tab B[MATRIX_K / VF][MATRIX_N * VF];
TAcc C[MATRIX_M][MATRIX_N];
TResult D[MATRIX_M][MATRIX_N];
TResult DRef[MATRIX_M][MATRIX_N];

matrix_rand<Tab>(MATRIX_M, MATRIX_K, (Tab *)A, Tab(1));
matrix_rand<Tab>(MATRIX_K / VF, MATRIX_N * VF, (Tab *)B, Tab(1));

matrix_fill(MATRIX_M, MATRIX_N, (TAcc *)C, TAcc(1));
matrix_fill(MATRIX_M, MATRIX_N, (TResult *)D, TResult(1));
matrix_fill(MATRIX_M, MATRIX_N, (TResult *)DRef, TResult(1));

big_matrix<TAcc, MATRIX_M, MATRIX_N> MC((TAcc *)&C);
big_matrix<TResult, MATRIX_M, MATRIX_N> MD((TResult *)&D);
big_matrix<Tab, MATRIX_M, MATRIX_K> MA((Tab *)&A);
big_matrix<Tab, MATRIX_K / VF, MATRIX_N * VF> MB((Tab *)&B);

matrix_multiply<Tab, TAcc, TResult, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
B_layout, VF>(MD, MC, MA, MB);
matrix_multiply_ref<Tab, Tab, TResult, VF>(
(Tab *)A, (Tab *)B, (TResult *)DRef, MATRIX_M, MATRIX_N, MATRIX_K / VF);
assert(matrix_compare(MATRIX_M, MATRIX_N, (TResult *)D, (TResult *)DRef));
}

template <typename TLow, typename THigh, size_t TM, size_t TN, size_t TK,
layout B_layout, size_t VF>
void test_combo() {
test<TLow, TLow, THigh, TM, TN, TK, B_layout, VF>();
test<TLow, THigh, TLow, TM, TN, TK, B_layout, VF>();
test<TLow, TLow, TLow, TM, TN, TK, B_layout, VF>();
test<TLow, THigh, THigh, TM, TN, TK, B_layout, VF>();
}

template <typename TLow, typename THigh, layout B_layout, size_t VF>
void test_all() {
test_combo<TLow, THigh, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16, B_layout, VF>();
test_combo<TLow, THigh, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16, B_layout, VF>();
test_combo<TLow, THigh, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16, B_layout, VF>();
test_combo<TLow, THigh, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32, B_layout, VF>();
test_combo<TLow, THigh, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16, B_layout, VF>();
test_combo<TLow, THigh, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32, B_layout, VF>();
}
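
Not part of the patch: a hypothetical driver translation unit showing how this header could be instantiated, following the pattern of other drivers in this directory. The file name, RUN lines, using-directives and the VF value of 2 are assumptions for illustration only.

// joint_matrix_bf16_16bit.cpp -- hypothetical driver for the header above.
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
#include "common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#include "joint_matrix_16bit_impl.hpp"

int main() {
  // bfloat16 as the low-precision type, float as the high-precision type,
  // packed (VNNI) B layout with an assumed vectorization factor of 2.
  test_all<sycl::ext::oneapi::bfloat16, float, layout::ext_intel_packed, 2>();
  return 0;
}
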