[SYCL][CUDA][MATRIX] joint_matrix_bmad implementation #5363

Closed · wants to merge 20 commits
164 changes: 163 additions & 1 deletion sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -22,6 +22,7 @@ enum class matrix_layout { row_major, col_major, packed_a, packed_b };

namespace precision {
class tf32 {};
class b1 {};
} // namespace precision

template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
@@ -113,6 +114,9 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1)
__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1)
__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, accumulator, 8, 8, 2)

// int32_t accumulator used by the single-bit (b1) MMA
__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 8, 8, 2)

#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR

#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N, type, \
@@ -153,6 +157,32 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4)
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4)

// single-bit (b1) matrices for the m8n8k128 bitwise MAD shapes
template <matrix_layout Layout>
struct joint_matrix<
precision::b1, matrix_use::a, 8, 128, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
joint_matrix() {
static_assert((Layout == matrix_layout::row_major),
"For the matrix_use::a case, matrix_layout::row_major must "
"be used for Bitwise MAD");
};
int32_t data;
};

template <matrix_layout Layout>
struct joint_matrix<
precision::b1, matrix_use::b, 128, 8, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
joint_matrix() {
static_assert((Layout == matrix_layout::col_major),
"For the matrix_use::b case, matrix_layout::col_major must "
"be used for Bitwise MAD");
};
int32_t data;
};
#undef __SYCL_JOINT_MATRIX_OVERLOAD

template <typename Group, typename T, matrix_use Use, size_t NumRows,
@@ -342,6 +372,9 @@ struct joint_matrix_load_impl<
} else if constexpr (NumRows == 32 && NumCols == 8) {
__imma_m32n8k16_ld_c(destptr, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 8) {
__bmma_m8n8k128_ld_c(destptr, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same<S, float>::value) {
@@ -381,6 +414,16 @@
matrix_use::accumulator) {
__dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id<Layout>());
}
} else if constexpr (std::is_same<S, sycl::ext::oneapi::experimental::
matrix::precision::b1>::value) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
if constexpr (NumRows == 8 && NumCols == 128) {
__bmma_m8n8k128_ld_a_b1(&res.data, tileptr, stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 128 && NumCols == 8) {
__bmma_m8n8k128_ld_b_b1(&res.data, tileptr, stride,
get_layout_id<Layout>());
}
}
}
};
@@ -458,6 +501,10 @@ struct joint_matrix_store_impl<
__dmma_m8n8k4_st_c_f64(dst.get(),
reinterpret_cast<double *>(&src.wi_marray), stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same<T, int32_t>::value) {
__bmma_m8n8k128_st_c_i32(dst.get(),
reinterpret_cast<int32_t *>(&src.wi_marray),
stride, get_layout_id<Layout>());
}
}
};
@@ -486,6 +533,33 @@ struct joint_matrix_mad_impl {
C);
};

template <std::size_t M, std::size_t K, std::size_t N,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
class BinaryOperation, typename Cond = void>
struct joint_matrix_bmad_impl {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
M, N, LayoutC, sycl::sub_group>
bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
sycl::ext::oneapi::experimental::matrix::precision::b1,
sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
sycl::sub_group>
A,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
sycl::ext::oneapi::experimental::matrix::precision::b1,
sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
sycl::sub_group>
B,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
N, LayoutC, sycl::sub_group>
C,
BinaryOperation Op);
};

template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();
@@ -686,6 +760,63 @@ struct joint_matrix_mad_impl<
};
#endif // __cplusplus >= 201703L

#if __cplusplus >= 201703L // if constexpr usage
template <std::size_t M, std::size_t K, std::size_t N,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
class BinaryOperation>
struct joint_matrix_bmad_impl<
M, K, N, LayoutC, BinaryOperation,
typename std::enable_if_t<(
LayoutC ==
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major ||
LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
col_major)>> {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
M, N, LayoutC, sycl::sub_group>
bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
sycl::ext::oneapi::experimental::matrix::precision::b1,
sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
sycl::sub_group>
A,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
sycl::ext::oneapi::experimental::matrix::precision::b1,
sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
sycl::sub_group>
B,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
N, LayoutC, sycl::sub_group>
C,
BinaryOperation Op) {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N,
LayoutC, sycl::sub_group>
D;

if constexpr (std::is_same<
BinaryOperation,
sycl::bit_and<sycl::ext::oneapi::experimental::matrix::
precision::b1>>::value) {
__bmma_m8n8k128_mma_and_popc_b1(
reinterpret_cast<int32_t *>(&D.wi_marray), &A.data, &B.data,
reinterpret_cast<int32_t *>(&C.wi_marray), 1);
} else if constexpr (std::is_same<
BinaryOperation,
sycl::bit_xor<sycl::ext::oneapi::experimental::
matrix::precision::b1>>::value) {
__bmma_m8n8k128_mma_xor_popc_b1(
reinterpret_cast<int32_t *>(&D.wi_marray), &A.data, &B.data,
reinterpret_cast<int32_t *>(&C.wi_marray), 1);
}
return D;
}
};
#endif // __cplusplus >= 201703L
} // namespace detail

namespace experimental {
@@ -696,7 +827,9 @@ template <typename Group, typename S, typename T, matrix_use Use,
access::address_space Space,
std::enable_if_t<std::is_same<S, T>::value ||
(std::is_same<S, precision::tf32>::value &&
std::is_same<T, float>::value),
std::is_same<T, float>::value) ||
(std::is_same<S, precision::b1>::value &&
std::is_same<T, uint32_t>::value),
bool> = true>
void joint_matrix_load(
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
@@ -777,6 +910,35 @@ float round_to_tf32(float a) {
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename Group, std::size_t M, std::size_t K, std::size_t N,
matrix_layout LayoutC, class BinaryOperation>
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_bmad(
Group sg,
joint_matrix<precision::b1, matrix_use::a, M, K, matrix_layout::row_major,
Group>
A,
joint_matrix<precision::b1, matrix_use::b, K, N, matrix_layout::col_major,
Group>
B,
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group> C,
BinaryOperation Op) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
return sycl::ext::oneapi::detail::joint_matrix_bmad_impl<M, K, N, LayoutC,
BinaryOperation>{}
.bmad(A, B, C, Op);
#else
std::ignore = sg;
std::ignore = A;
std::ignore = B;
std::ignore = C;
std::ignore = Op;
throw runtime_error("joint_matrix_bmad is "
"only supported by CUDA devices",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace matrix
} // namespace experimental
} // namespace oneapi
@@ -0,0 +1,78 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of the dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation. The K dimension is counted in
// single-bit elements and is stored packed, 32 bits per uint32_t word.
constexpr int M = 8;   // number of rows of a / number of rows of accumulator.
constexpr int N = 8;   // number of cols of b / number of cols of accumulator.
constexpr int K = 128; // number of cols of a / number of rows of b
                       // (stored as K / 32 uint32_t words).

// Each bit of each uint32_t A/B array element is an element of a single-bit
// matrix. joint_matrix_bmad performs binary dot products on these matrices
// (see M. Rastegari et al., Computer Vision – ECCV 2016, 525-542, and A. Li et
// al., IEEE Transactions on Parallel and Distributed Systems, 32(7):1878-1891,
// 2021).
uint32_t A[M * (K / 32)];
uint32_t B[(K / 32) * N];
int32_t C[M * N];
int32_t D[M * N];
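
// A minimal packing sketch (not part of the original test): one plausible way
// to fill A from a full-size boolean matrix, assuming an LSB-first bit order
// within each 32-bit word. The helper name and the bit order are assumptions
// made here for illustration only; the required b1 packing is defined by the
// CUDA/PTX documentation.
inline void pack_bits_row_major(const bool *src, uint32_t *dst, int rows,
                                int cols) {
  for (int i = 0; i < rows * (cols / 32); ++i)
    dst[i] = 0; // clear all packed words first
  for (int i = 0; i < rows; ++i)
    for (int k = 0; k < cols; ++k)
      if (src[i * cols + k])
        dst[i * (cols / 32) + k / 32] |= 1u << (k % 32);
}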

int main() {

buffer<uint32_t, 1> bufA(A, range<1>(M * (K / 32)));
buffer<uint32_t, 1> bufB(B, range<1>((K / 32) * N));
buffer<int32_t, 1> bufC(C, range<1>(M * N));
buffer<int32_t, 1> bufD(D, range<1>(M * N));

queue q;

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class row_col>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<int32_t, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

joint_matrix<precision::b1, matrix_use::a, M, K,
matrix_layout::row_major>
sub_a;

joint_matrix<precision::b1, matrix_use::b, K, N,
matrix_layout::col_major>
sub_b;

//CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 8)
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.a.row.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128)
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.b.col.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128)
joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
//CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1(i32 %3, i32 %4, i32 %1, i32 %2)
sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
sycl::bit_xor<precision::b1>());
//CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1(i32 %3, i32 %4, i32 %6, i32 %7)
sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
sycl::bit_and<precision::b1>());
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k128.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %9, i32 %10, i32 8)
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

return 0;
};
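
For reference, the NVVM bmma intrinsics checked above replace the multiply step of a conventional MMA with a bitwise AND or XOR followed by a population count that is accumulated into the int32_t C fragment. The sketch below is a hypothetical host-side reference for the XOR variant; the function names, the row-major/column-major packing of A/B into uint32_t words, and the LSB-first bit order are assumptions for illustration and are not part of this PR.

#include <cstdint>

static int popcount32(uint32_t x) {
  int n = 0;
  while (x) { x &= x - 1; ++n; } // clear the lowest set bit each iteration
  return n;
}

// D[i][j] = C[i][j] + sum over the K bit positions of popcount(a_bit XOR b_bit),
// processed 32 bits (one packed word) at a time.
void bmad_reference_xor(const uint32_t *A, const uint32_t *B, const int32_t *C,
                        int32_t *D, int M, int N, int K) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      int32_t acc = C[i * N + j];
      for (int w = 0; w < K / 32; ++w)
        acc += popcount32(A[i * (K / 32) + w] ^ B[j * (K / 32) + w]);
      D[i * N + j] = acc;
    }
}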