-
Notifications
You must be signed in to change notification settings - Fork 790
[SYCL][Matrix] Enable wi_slice for joint_matrix #4979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
16c0bbe
[Matrix] Enable wi_slice for joint_matrix
yubingex007-a11y 4803165
Change implementation by using wi_elem
yubingex007-a11y 9a867ad
Fix preCI's fail and address dounia's comments
yubingex007-a11y d236a56
Remove useless wi_slice_t
yubingex007-a11y a894d04
Remove useless comments
yubingex007-a11y 93c3e22
Address Dounia&Alexey's comments
yubingex007-a11y 3ffe959
Fix clang-format issue
yubingex007-a11y 1d2dab3
Choose a different number for element-wise multiplication
yubingex007-a11y File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,11 @@ template <int D> struct spv_scope_traits<sycl::group<D>> { | |
constexpr static auto value = __spv::Scope::Workgroup; | ||
}; | ||
|
||
template <typename T, size_t NumRows, size_t NumCols, | ||
matrix_layout Layout = matrix_layout::row_major, | ||
typename Group = sycl::sub_group> | ||
class wi_slice; | ||
|
||
template <typename T, size_t NumRows, size_t NumCols, | ||
matrix_layout Layout = matrix_layout::row_major, | ||
typename Group = sycl::sub_group> | ||
|
@@ -58,6 +63,11 @@ struct joint_matrix { | |
PI_INVALID_DEVICE); | ||
#endif // __SYCL_DEVICE_ONLY__ | ||
} | ||
|
||
inline __SYCL_ALWAYS_INLINE wi_slice<T, NumRows, NumCols, Layout, Group> | ||
get_wi_data() { | ||
return wi_slice<T, NumRows, NumCols, Layout, Group>(*this); | ||
} | ||
}; | ||
|
||
template <typename Group, typename T, size_t NumRows, size_t NumCols, | ||
|
@@ -191,6 +201,70 @@ joint_matrix_mad(Group sg, joint_matrix<T1, M, K, LayoutA, Group> &mA, | |
PI_INVALID_DEVICE); | ||
#endif // __SYCL_DEVICE_ONLY__ | ||
} | ||
|
||
template <typename T, size_t NumRows, size_t NumCols, | ||
matrix_layout Layout = matrix_layout::row_major, | ||
typename Group = sycl::sub_group> | ||
class wi_element { | ||
joint_matrix<T, NumRows, NumCols, Layout, Group> &M; | ||
std::size_t idx; | ||
|
||
public: | ||
wi_element(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat, | ||
std::size_t i) | ||
: M(Mat), idx(i) {} | ||
operator T() { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
return __spirv_VectorExtractDynamic(M.spvm, idx); | ||
#else | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_INVALID_DEVICE); | ||
#endif // __SYCL_DEVICE_ONLY__ | ||
} | ||
wi_element &operator=(const T &rhs) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); | ||
return *this; | ||
#else | ||
(void)rhs; | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_INVALID_DEVICE); | ||
#endif // __SYCL_DEVICE_ONLY__ | ||
} | ||
wi_element &operator*=(const T &rhs) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
M.spvm = __spirv_VectorInsertDynamic( | ||
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) * rhs, idx); | ||
return *this; | ||
#else | ||
(void)rhs; | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_INVALID_DEVICE); | ||
#endif // __SYCL_DEVICE_ONLY__ | ||
} | ||
// TODO: add other arithmetic operators | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yubingex007-a11y please do not forget to add overloading for other operators |
||
}; | ||
|
||
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout, | ||
typename Group> | ||
class wi_slice { | ||
joint_matrix<T, NumRows, NumCols, Layout, Group> &M; | ||
|
||
public: | ||
wi_slice(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat) : M(Mat) {} | ||
size_t length() { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); | ||
#else | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_INVALID_DEVICE); | ||
#endif // __SYCL_DEVICE_ONLY__ | ||
} | ||
wi_element<T, NumRows, NumCols, Layout, Group> operator[](size_t i) { | ||
return wi_element<T, NumRows, NumCols, Layout, Group>(M, i); | ||
} | ||
}; | ||
|
||
} // namespace experimental::matrix | ||
} // namespace oneapi | ||
} // namespace ext | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
// RUN: %clangxx -fsycl -O2 %s -o %t.out | ||
// XFAIL: * | ||
#include <CL/sycl.hpp> | ||
#if (SYCL_EXT_ONEAPI_MATRIX == 2) | ||
#include <iostream> | ||
|
||
using namespace sycl; | ||
using namespace sycl::ext::oneapi::experimental::matrix; | ||
|
||
#define TILE_SZ 16 | ||
#define TM (TILE_SZ - 4) | ||
#define TN (TILE_SZ - 4) | ||
#define TK (4 * TILE_SZ - 16) | ||
|
||
#define SG_SZ 16 | ||
|
||
/// Minimal non-owning wrapper around a row-major NUM_ROWS x NUM_COLS matrix
/// buffer; the caller retains ownership of the storage and must keep it alive
/// for the lifetime of the wrapper. (Members of a struct are public by
/// default, so the redundant access specifiers were dropped.)
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  T *mat; // borrowed pointer to the element storage

  big_matrix(T *data) : mat(data) {}
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
};
|
||
/// Device-side GEMM via the joint_matrix extension: computes
/// C = (A * B + C) * 2, where the trailing doubling is performed through the
/// per-work-item wi_slice/wi_element view returned by get_wi_data().
/// A is M x K int8 row-major, B is (K/4) x (N*4) int8 in packed (VNNI)
/// layout, C is M x N int32 row-major.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  // B => K/4 x N*4, A => M x K, C => M, N
  // stride should be X's cols, e.g., B's stride = N*4
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
  // One sub-group handles one TM x TN tile of C.
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [accA, accB, accC, M, N, K](nd_item<2> spmd_item)

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup these functions will be called once by the subgroup no
           // code divergence between the workitems
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           // Top-left coordinates of the tile owned by this sub-group.
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<int8_t, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the updated VNNI sizes along with
           // the packed_b layout. By default, the layout is row_major and size
           // is (TK, TN).
           joint_matrix<int8_t, TK, TN, matrix_layout::packed_b> sub_b(sg);
           joint_matrix<int32_t, TM, TN> sub_c(sg);

           // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
           // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           // Accumulate over the K dimension, one TK-wide step at a time.
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, matrix_layout::row_major);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 4) * (N * 4) +
                                   sg_starty / SG_SZ * TN * 4,
                               N * 4, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           // Element-wise post-processing: double every element of C that
           // this work-item owns, exercising the new wi_slice API.
           auto wi_slice_c = sub_c.get_wi_data();
           for (int i = 0; i < wi_slice_c.length(); i++) {
             wi_slice_c[i] *= 2;
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}
|
||
// Problem sizes: two tiles per dimension.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];         // input A, row-major int8
int8_t B[MATRIX_K / 4][MATRIX_N * 4]; // input B, VNNI-packed int8
int32_t C[MATRIX_M][MATRIX_N];        // device result
int32_t D[MATRIX_M][MATRIX_N];        // host reference result
|
||
/// Host reference for the VNNI int8 matmul: each int32 of A_mem/B_mem packs
/// four int8 values that are multiplied pairwise and accumulated into C_mem,
/// after which every C element is doubled (mirroring the device-side
/// element-wise *= 2 pass done through wi_slice).
void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      int32_t &out = C_mem[row * N + col];
      for (int kk = 0; kk < K; ++kk) {
        // Reinterpret the packed int32 as its four int8 lanes.
        const char *packed_a = (const char *)(A_mem + row * K + kk);
        const char *packed_b = (const char *)(B_mem + kk * N + col);
        for (int lane = 0; lane < 4; ++lane)
          out += packed_a[lane] * packed_b[lane];
      }
      out *= 2;
    }
  }
}
|
||
/// Initializes the inputs, runs the device multiplication and the host
/// reference, reports whether the two results agree, and prints both result
/// matrices for inspection.
int main() {
  for (int row = 0; row < MATRIX_M; row++)
    for (int col = 0; col < MATRIX_K; col++)
      A[row][col] = row + 2 * col;

  for (int row = 0; row < MATRIX_K / 4; row++)
    for (int col = 0; col < MATRIX_N * 4; col++)
      B[row][col] = row + col;

  // Both result matrices start from the same non-zero accumulator.
  for (int row = 0; row < MATRIX_M; row++)
    for (int col = 0; col < MATRIX_N; col++) {
      C[row][col] = 1;
      D[row][col] = 1;
    }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC((int32_t *)&C);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD((int32_t *)&D);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);
  big_matrix<int8_t, MATRIX_K / 4, MATRIX_N * 4> MB((int8_t *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 4);

  // Compare device output (C) against the host reference (D).
  bool res = true;
  for (int row = 0; row < MATRIX_M; row++)
    for (int col = 0; col < MATRIX_N; col++)
      res = res && (C[row][col] == D[row][col]);

  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";

  for (int row = 0; row < MATRIX_M; row++) {
    for (int col = 0; col < MATRIX_N; col++)
      std::cout << C[row][col] << ", ";
    std::cout << "\n";
  }
  std::cout << std::endl;
  for (int row = 0; row < MATRIX_M; row++) {
    for (int col = 0; col < MATRIX_N; col++)
      std::cout << D[row][col] << ", ";
    std::cout << "\n";
  }
}
#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The width of `size_t` is platform dependent, I think. Typically it becomes `i32` or `i64` in LLVM IR. I would prefer a more specific type, like `uint32_t`. Is 32 bits enough for indexing in slices — what do you think, @dkhaldi?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
32 bits is enough. But size_t is widely used in other APIs.
Specifically, what if the user is calculating this iterator using some WI id or other id which is also size_t, will that work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we explicitly say that this parameter must be `uint32_t`, there may be a warning about type narrowing during compilation. But now I think maybe I was wrong and we really should stick to `size_t`. At the LLVM IR level we can handle it in the same way as `memcpy` (https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic), i.e. it is an overloaded type.