@@ -579,7 +579,7 @@ def random_matrix():
579
579
580
580
581
581
@pytest .mark .parametrize ("dtype" , _numeric_types )
582
- def test_matmul_largish (dtype , random_matrix ):
582
+ def test_matmul_largish_square (dtype , random_matrix ):
583
583
q = get_queue_or_skip ()
584
584
skip_if_dtype_not_supported (dtype , q )
585
585
@@ -598,6 +598,7 @@ def test_matmul_largish(dtype, random_matrix):
598
598
assert dpt .allclose (x1 , x2 , atol = tol , rtol = tol )
599
599
assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
600
600
601
+ # check stided input
601
602
m_np = m_np [:- 1 , :- 1 ]
602
603
x_np = np .matmul (m_np .T , m_np )
603
604
@@ -610,6 +611,40 @@ def test_matmul_largish(dtype, random_matrix):
610
611
assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
611
612
612
613
614
+ @pytest .mark .parametrize ("dtype" , _numeric_types )
615
+ def test_matmul_largish_rect (dtype , random_matrix ):
616
+ q = get_queue_or_skip ()
617
+ skip_if_dtype_not_supported (dtype , q )
618
+
619
+ m_np = random_matrix .astype (dtype )[:, :- 1 ]
620
+ x_np = np .matmul (m_np .T [:- 2 , :], m_np )
621
+
622
+ m = dpt .asarray (m_np )
623
+ mmT = m .mT [:- 2 , :]
624
+ mT = dpt .asarray (mmT , copy = True , order = "C" )
625
+ x1 = dpt .matmul (mmT , m )
626
+ x2 = dpt .matmul (mT , m )
627
+
628
+ tol = 0
629
+ if dpt .isdtype (x2 .dtype , ("real floating" , "complex floating" )):
630
+ tol = 32 * dpt .finfo (x2 .dtype ).eps
631
+
632
+ assert dpt .allclose (x1 , x2 , atol = tol , rtol = tol )
633
+ assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
634
+
635
+ m_np = m_np [:- 1 , :- 1 ]
636
+ x_np = np .matmul (m_np .T [:- 2 , :], m_np )
637
+
638
+ m = m [:- 1 , :- 1 ]
639
+ mmT = m .mT [:- 2 , :]
640
+ mT = dpt .asarray (mmT , copy = True , order = "C" )
641
+ x1 = dpt .matmul (mmT , m )
642
+ x2 = dpt .matmul (mT , m )
643
+
644
+ assert dpt .allclose (x1 , x2 , atol = tol , rtol = tol )
645
+ assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
646
+
647
+
613
648
@pytest .mark .parametrize ("dtype" , _numeric_types )
614
649
def test_tensordot_outer (dtype ):
615
650
q = get_queue_or_skip ()
0 commit comments