@@ -336,21 +336,36 @@ def test_matmul_dtype():
336
336
337
337
@pytest .mark .parametrize ("dt1" , _numeric_types )
338
338
@pytest .mark .parametrize ("dt2" , _numeric_types )
339
- def test_matmul_type_promotion (dt1 , dt2 ):
339
+ @pytest .mark .parametrize ("order" , ["C" , "K" ])
340
+ def test_matmul_type_promotion (dt1 , dt2 , order ):
340
341
get_queue_or_skip ()
341
342
342
343
q = get_queue_or_skip ()
343
344
skip_if_dtype_not_supported (dt1 , q )
344
345
skip_if_dtype_not_supported (dt2 , q )
345
346
346
- m1 = dpt .ones ((10 , 10 ), dtype = dt1 )
347
- m2 = dpt .ones ((10 , 10 ), dtype = dt2 )
347
+ b , n , k , m = 8 , 10 , 17 , 10
348
+ m1 = dpt .ones ((1 , n , k ), dtype = dt1 )
349
+ m2 = dpt .ones ((b , k , m ), dtype = dt2 )
350
+ expected_dt = dpt .result_type (m1 , m2 )
348
351
349
- r = dpt .matmul (m1 , m2 )
350
- assert r .shape == (
351
- 10 ,
352
- 10 ,
353
- )
352
+ r = dpt .matmul (m1 , m2 , order = order )
353
+ assert r .shape == (b , n , m )
354
+ assert r .dtype == expected_dt
355
+
356
+ m1 = dpt .ones ((b , n , k ), dtype = dt1 )
357
+ m2 = dpt .ones ((1 , k , m ), dtype = dt2 )
358
+
359
+ r = dpt .matmul (m1 , m2 , order = order )
360
+ assert r .shape == (b , n , m )
361
+ assert r .dtype == expected_dt
362
+
363
+ m1 = dpt .ones ((n , k ), dtype = dt1 )
364
+ m2 = dpt .ones ((k , m ), dtype = dt2 )
365
+
366
+ r = dpt .matmul (m1 , m2 , order = order )
367
+ assert r .shape == (n , m )
368
+ assert r .dtype == expected_dt
354
369
355
370
356
371
def test_matmul_invalid_dtype ():
0 commit comments