Skip to content

Commit 03c36eb

Browse files
Use order keyword in test of type promotion for matmul
1 parent 3ce9b59 commit 03c36eb

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,21 +336,36 @@ def test_matmul_dtype():
336336

337337
@pytest.mark.parametrize("dt1", _numeric_types)
338338
@pytest.mark.parametrize("dt2", _numeric_types)
339-
def test_matmul_type_promotion(dt1, dt2):
339+
@pytest.mark.parametrize("order", ["C", "K"])
340+
def test_matmul_type_promotion(dt1, dt2, order):
340341
get_queue_or_skip()
341342

342343
q = get_queue_or_skip()
343344
skip_if_dtype_not_supported(dt1, q)
344345
skip_if_dtype_not_supported(dt2, q)
345346

346-
m1 = dpt.ones((10, 10), dtype=dt1)
347-
m2 = dpt.ones((10, 10), dtype=dt2)
347+
b, n, k, m = 8, 10, 17, 10
348+
m1 = dpt.ones((1, n, k), dtype=dt1)
349+
m2 = dpt.ones((b, k, m), dtype=dt2)
350+
expected_dt = dpt.result_type(m1, m2)
348351

349-
r = dpt.matmul(m1, m2)
350-
assert r.shape == (
351-
10,
352-
10,
353-
)
352+
r = dpt.matmul(m1, m2, order=order)
353+
assert r.shape == (b, n, m)
354+
assert r.dtype == expected_dt
355+
356+
m1 = dpt.ones((b, n, k), dtype=dt1)
357+
m2 = dpt.ones((1, k, m), dtype=dt2)
358+
359+
r = dpt.matmul(m1, m2, order=order)
360+
assert r.shape == (b, n, m)
361+
assert r.dtype == expected_dt
362+
363+
m1 = dpt.ones((n, k), dtype=dt1)
364+
m2 = dpt.ones((k, m), dtype=dt2)
365+
366+
r = dpt.matmul(m1, m2, order=order)
367+
assert r.shape == (n, m)
368+
assert r.dtype == expected_dt
354369

355370

356371
def test_matmul_invalid_dtype():

0 commit comments

Comments
 (0)