Skip to content

Commit 39cf672

Browse files
committed
Tweaks to matmul and gemm kernels
Fixes a missing indexer in gemm functor with threading along `nm` dimensions Fixes `matmul` broadcasting, which was broadcasting in some unnecessary cases
1 parent af41424 commit 39cf672

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

dpctl/tensor/_linear_algebra_functions.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -823,9 +823,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
823823
sycl_queue=exec_q,
824824
order=order,
825825
)
826-
if x1.shape != res_shape:
826+
if x1.shape != x1_broadcast_shape:
827827
x1 = dpt.broadcast_to(x1, x1_broadcast_shape)
828-
if x2.shape != res_shape:
828+
if x2.shape != x2_broadcast_shape:
829829
x2 = dpt.broadcast_to(x2, x2_broadcast_shape)
830830
ht_dot_ev, binary_ev = tli._dot(
831831
x1=x1,
@@ -875,9 +875,10 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
875875
order=order,
876876
)
877877

878-
if x1.shape != res_shape:
878+
if x1.shape != x1_broadcast_shape:
879879
x1 = dpt.broadcast_to(x1, x1_broadcast_shape)
880-
buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape)
880+
if buf2.shape != x2_broadcast_shape:
881+
buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape)
881882
ht_dot_ev, binary_ev = tli._dot(
882883
x1=x1,
883884
x2=buf2,
@@ -929,8 +930,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
929930
order=order,
930931
)
931932

932-
buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape)
933-
if x2.shape != res_shape:
933+
if buf1.shape != x1_broadcast_shape:
934+
buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape)
935+
if x2.shape != x2_broadcast_shape:
934936
x2 = dpt.broadcast_to(x2, x2_broadcast_shape)
935937
ht_dot_ev, binary_ev = tli._dot(
936938
x1=buf1,
@@ -994,8 +996,10 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
994996
order=order,
995997
)
996998

997-
buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape)
998-
buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape)
999+
if buf1.shape != x1_broadcast_shape:
1000+
buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape)
1001+
if buf2.shape != x2_broadcast_shape:
1002+
buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape)
9991003
ht_, _ = tli._dot(
10001004
x1=buf1,
10011005
x2=buf2,

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,8 @@ class GemmFunctorThreadNM
533533
size_t g_j = g_j0 + lane_id;
534534
vec[lane_id] =
535535
(g_j < m && g_s < k)
536-
? static_cast<resT>(rhs[g_s * b_st0 + g_j * b_st1])
536+
? static_cast<resT>(
537+
rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)])
537538
: resT(0);
538539
}
539540

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_matmul_simple(dtype):
7272
q = get_queue_or_skip()
7373
skip_if_dtype_not_supported(dtype, q)
7474

75-
n, m = 100, 17
75+
n, m = 235, 17
7676
m1 = dpt.ones((m, n), dtype=dtype)
7777
m2 = dpt.ones((n, m), dtype=dtype)
7878

0 commit comments

Comments
 (0)