Skip to content

Commit 978ec48

Browse files
Address Douniai&Dmitry's comments
1 parent ef015f4 commit 978ec48

File tree

4 files changed

+6
-181
lines changed

4 files changed

+6
-181
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ template <typename T, std::size_t R, std::size_t C,
9090
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
9191
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
9292
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
93-
__spirv_JointMatrixFillINTEL(const T &v, __spv::Scope::Flag Sc = S);
93+
__spirv_CompositeConstruct(const T v);
9494

9595
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
9696
#error \

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
197197
inline __SYCL_ALWAYS_INLINE void
198198
joint_matrix_fill(Group sg,
199199
joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
200-
const T &v) {
200+
const T v) {
201201
#ifdef __SYCL_DEVICE_ONLY__
202-
res.spvm =
203-
__spirv_JointMatrixFillINTEL<T, NumRows, NumCols,
204-
spv_matrix_layout_traits<Layout>::value>(
205-
v, spv_scope_traits<Group>::value);
202+
res.spvm = __spirv_CompositeConstruct<T, NumRows, NumCols>(v);
206203
#else
207204
(void)res;
208205
(void)v;

sycl/test/matrix/matrix-int8-test-fill.cpp

Lines changed: 0 additions & 169 deletions
This file was deleted.

sycl/test/matrix/matrix-int8-test.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,7 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C, big_matrix<T2, N
6868

6969
// AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
7070
// strideX = X's cols, so strideC = N, strideA = K, strideB = N*4
71-
joint_matrix_load(sg, sub_c,
72-
accC.get_pointer() + (sg_startx * TM) * N +
73-
sg_starty / SG_SZ * TN,
74-
N, matrix_layout::row_major);
71+
joint_matrix_fill(sg, sub_c, 0);
7572
for (int k = 0; k < K / TK; k += 1) {
7673
joint_matrix_load(
7774
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
@@ -129,8 +126,8 @@ int main() {
129126
}
130127
for (int i = 0; i < MATRIX_M; i++) {
131128
for (int j = 0; j < MATRIX_N; j++) {
132-
C[i][j] = 1;
133-
D[i][j] = 1;
129+
C[i][j] = 0;
130+
D[i][j] = 0;
134131
}
135132
}
136133

0 commit comments

Comments
 (0)