32
32
}
33
33
34
34
35
- def verify_batch_matmul (batch , M , N , K ):
36
- x = te .placeholder ((batch , M , K ), name = "x" )
37
- y = te .placeholder ((batch , N , K ), name = "y" )
35
+ def verify_batch_matmul (x_batch , y_batch , M , N , K ):
36
+ x = te .placeholder ((x_batch , M , K ), name = "x" )
37
+ y = te .placeholder ((y_batch , N , K ), name = "y" )
38
38
dtype = x .dtype
39
39
40
40
# use memoize to pickle the test data for next time use
41
41
@memoize ("topi.tests.test_topi_batch_matmul" )
42
42
def get_ref_data ():
43
- a_np = np .random .uniform (size = (batch , M , K )).astype (dtype )
44
- b_np = np .random .uniform (size = (batch , N , K )).astype (dtype )
43
+ a_np = np .random .uniform (size = (x_batch , M , K )).astype (dtype )
44
+ b_np = np .random .uniform (size = (y_batch , N , K )).astype (dtype )
45
45
c_np = tvm .topi .testing .batch_matmul (a_np , b_np )
46
46
return (a_np , b_np , c_np )
47
47
@@ -67,10 +67,13 @@ def check_device(device, ctx):
67
67
68
68
@tvm .testing .uses_gpu
69
69
def test_batch_matmul ():
70
- verify_batch_matmul (1 , 16 , 16 , 32 )
71
- verify_batch_matmul (5 , 16 , 16 , 32 )
72
- verify_batch_matmul (5 , 16 , 20 , 32 )
73
- verify_batch_matmul (30 , 16 , 20 , 32 )
70
+ verify_batch_matmul (1 , 1 , 16 , 16 , 32 )
71
+ verify_batch_matmul (5 , 5 , 16 , 16 , 32 )
72
+ verify_batch_matmul (5 , 5 , 16 , 20 , 32 )
73
+ verify_batch_matmul (30 , 30 , 16 , 20 , 32 )
74
+ # Test batch broadcasting.
75
+ verify_batch_matmul (1 , 5 , 16 , 16 , 32 )
76
+ verify_batch_matmul (5 , 1 , 16 , 16 , 32 )
74
77
75
78
76
79
if __name__ == "__main__" :
0 commit comments