Skip to content

Commit 61ec3d3

Browse files
Added test for matmul based on discovered failure due to typo
1 parent a885034 commit 61ec3d3

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def random_matrix():
579579

580580

581581
@pytest.mark.parametrize("dtype", _numeric_types)
582-
def test_matmul_largish(dtype, random_matrix):
582+
def test_matmul_largish_square(dtype, random_matrix):
583583
q = get_queue_or_skip()
584584
skip_if_dtype_not_supported(dtype, q)
585585

@@ -598,6 +598,7 @@ def test_matmul_largish(dtype, random_matrix):
598598
assert dpt.allclose(x1, x2, atol=tol, rtol=tol)
599599
assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol)
600600

601+
# check stided input
601602
m_np = m_np[:-1, :-1]
602603
x_np = np.matmul(m_np.T, m_np)
603604

@@ -610,6 +611,40 @@ def test_matmul_largish(dtype, random_matrix):
610611
assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol)
611612

612613

614+
@pytest.mark.parametrize("dtype", _numeric_types)
615+
def test_matmul_largish_rect(dtype, random_matrix):
616+
q = get_queue_or_skip()
617+
skip_if_dtype_not_supported(dtype, q)
618+
619+
m_np = random_matrix.astype(dtype)[:, :-1]
620+
x_np = np.matmul(m_np.T[:-2, :], m_np)
621+
622+
m = dpt.asarray(m_np)
623+
mmT = m.mT[:-2, :]
624+
mT = dpt.asarray(mmT, copy=True, order="C")
625+
x1 = dpt.matmul(mmT, m)
626+
x2 = dpt.matmul(mT, m)
627+
628+
tol = 0
629+
if dpt.isdtype(x2.dtype, ("real floating", "complex floating")):
630+
tol = 32 * dpt.finfo(x2.dtype).eps
631+
632+
assert dpt.allclose(x1, x2, atol=tol, rtol=tol)
633+
assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol)
634+
635+
m_np = m_np[:-1, :-1]
636+
x_np = np.matmul(m_np.T[:-2, :], m_np)
637+
638+
m = m[:-1, :-1]
639+
mmT = m.mT[:-2, :]
640+
mT = dpt.asarray(mmT, copy=True, order="C")
641+
x1 = dpt.matmul(mmT, m)
642+
x2 = dpt.matmul(mT, m)
643+
644+
assert dpt.allclose(x1, x2, atol=tol, rtol=tol)
645+
assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol)
646+
647+
613648
@pytest.mark.parametrize("dtype", _numeric_types)
614649
def test_tensordot_outer(dtype):
615650
q = get_queue_or_skip()

0 commit comments

Comments
 (0)