Skip to content

[SYCL][Matrix] Add initial get_coord API #7851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
Expand Down
43 changes: 43 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ class wi_element {
NumCols, Layout> &M;
std::size_t idx;

template <typename T1, size_t NRows, size_t NCols,
sycl::ext::oneapi::experimental::matrix::use Use1,
sycl::ext::oneapi::experimental::matrix::layout Layout1,
typename Grp>
friend std::tuple<uint32_t, uint32_t>
get_coord(wi_element<T1, NRows, NCols, Use1, Layout1, Grp> &);

public:
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, T, Use, NumRows, NumCols, Layout> &Mat,
Expand Down Expand Up @@ -165,6 +172,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
std::size_t idx;

template <typename T1, size_t NRows, size_t NCols,
sycl::ext::oneapi::experimental::matrix::use Use1,
sycl::ext::oneapi::experimental::matrix::layout Layout1,
typename Grp>
friend std::tuple<uint32_t, uint32_t>
get_coord(wi_element<T1, NRows, NCols, Use1, Layout1, Grp> &);

public:
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
Expand Down Expand Up @@ -308,6 +322,35 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

// End wi_element definition

template <typename T, size_t NumRows, size_t NumCols,
sycl::ext::oneapi::experimental::matrix::use Use,
sycl::ext::oneapi::experimental::matrix::layout Layout,
typename Group>
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t>
get_coord(wi_element<T, NumRows, NumCols, Use, Layout, Group> &we) {
#if defined(__SYCL_DEVICE_ONLY__)
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(we.M.spvm, we.idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
std::ignore = we;
throw runtime_error(
"get_coord is only supported on Intel XMX and AMX devices.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// To make host compilation possible, here the argument is not a wi_element
// type, but just base data types e.g. float, int8 etc.
template <typename T>
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord(T &we) {
std::ignore = we;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
}

// Begin wi_data definition

template <typename Group, typename T,
Expand Down
204 changes: 204 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord-basicA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 16

#define TM 8
#define TN SG_SZ
#define TK 32

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
T *mat;

public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};

template <typename T, size_t M, size_t K>
void sum_rows_ref(host_accessor<T, 2, access::mode::read_write> A,
host_accessor<int, 1, access::mode::read_write> sum_rows) {
int sum_rows_ref[M] = {0};
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < K; j++) {
sum_rows_ref[i] += A[i][j];
}
auto diff = sum_rows[i] - sum_rows_ref[i];
assert(std::fabs(static_cast<int>(diff)) <=
std::numeric_limits<int>::epsilon());
}
}

// clang-format off
/*
Here's how the data is distributed among work items

0 0 0 0
/
/
1 1 1 1
/
/
2 2 2 2
/
/
3 3 3 3

W0 --> 0 0 1 1 2 2 3 3 .... 7 7
wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 2] wi [0,15] --> i=0, [0, 30]
i=1, [0, 1] i=1, [0, 3] i=1, [0, 31]
i=2, [1, 0] i=2, [1, 2] i=2, [1, 30]
i=3, [1, 1] i=3, [1, 3] i=3, [1, 31]
i=4, [2, 0] i=4, [2, 2] ...
i=5, [2, 1] i=5, [2, 3]
... ....
i=14,[7, 0] i=14, [7, 2]
i=15,[7, 1] i=15, [7, 3] i=15, [7, 31]
*/
//clang-format on
std::tuple<uint32_t, uint32_t> get_coord_ref(int i, int wi_number) {
return std::make_tuple(i/2, ((i%2) + (wi_number*2)));
}

//clang-format off
/*
Here's how the distribution of the A matrix looks like for this test case

x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
<--------------------------------- SG1 --------------------------------->

x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x
<0> <1> <2> <3> <4> <5> <6> <7> ..... WORK ITEMS

Each work item has 16 elements <8 rows and 2 cols of the original matrix>
the data_slice in holds the matrix elements in the following order:

0 0 0 0
/
/
1 1 1 1
/
/
2 2 2 2
/
/
3 3 3 3

W0 --> 0 0 1 1 2 2 3 3 .... 7 7
*/
//clang-format on
template <typename T, size_t M, size_t K>
void matrix_sum_rows(queue q, big_matrix<T, M, K> &A, nd_range<2> &r) {
buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
// size of vector is known because SG size of set by the user in this case
int sum_rows[M] = {0};
buffer<int> sum_rows_v(sum_rows, M); // there are total of M rows
q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

auto v = sum_rows_v.get_access<access::mode::atomic>(cgh);
auto os = sycl::stream(100000, 6144, cgh);

cgh.parallel_for<class add_matrix>(
r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();

// TM = 8, TK = 32
joint_matrix<sub_group, int8_t, use::a, TM, TK, layout::row_major>
sub_a;

joint_matrix_load(
sg, sub_a, accA.get_pointer() + (global_idx * TM * K) + TK,
K);

// calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_a
int32_t sum_local_rows[M] = {0}; // 8 local rows, M total
// sub_a has 8x32 elements, 16 elements per WI, 2 per WI per row
auto data = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);

size_t global_index; // Index into the result array that holds the sums.

// Keep track of rows handled in this WI
int32_t handled_rows[M] = {-1};

// each WI calculates local sum of rows
for (int i = 0; i < data.length(); ++i) {
// get the index of the element in the submatrix
auto data_item = data[i];
auto [row, col] = sycl::ext::intel::experimental::matrix::get_coord(data_item);
global_index = row + global_idx*TM;

sum_local_rows[global_index] += data[i];

handled_rows[global_index] = 1;
}

for (int j=0; j < M; j++) {
if (handled_rows[j] == 1) {
global_index = j;
sum_local_rows[global_index] = reduce_over_group(
sg, sum_local_rows[global_index],
sycl::plus<>());
// only Groups leader perform the global reduction
if (global_idy % SG_SZ == 0) {
atomic_fetch_add(v[global_index],
sum_local_rows[global_index]);
}
}
}
}); // parallel for
}).wait();
sum_rows_ref<T, M, K>(bufA.get_host_access(), sum_rows_v.get_host_access());
}


static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];

int main() {
big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);

size_t NDRangeM = MATRIX_M / TM;
size_t NDRangeK = MATRIX_K / TK;
queue q;
nd_range<2> r({NDRangeM, NDRangeK * SG_SZ}, {1, 1 * SG_SZ});

for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_K; j++) {
A[i][j] = i;
}
}

matrix_sum_rows<int8_t, MATRIX_M, MATRIX_K>(q, MA, r);

std::cout << "Passed\n";

return 0;
}
Loading