@@ -26,5 +26,49 @@ def check_correct_assembly(type, elements, counts):
26
26
check_correct_assembly ('uint32' , 2 , 2 )
27
27
check_correct_assembly ('uint64' , 2 , 3 )
28
28
29
+ def test_vmlal_s16 ():
30
+ target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
31
+
32
+ def check_correct_assembly (N ):
33
+ K = tvm .var ("K" )
34
+ A = tvm .placeholder ((K , N ), dtype = "int8" , name = 'A' )
35
+ B = tvm .placeholder ((K , N ), dtype = "int8" , name = 'A' )
36
+ k = tvm .reduce_axis ((0 , K ))
37
+ C = tvm .compute ((N , ), lambda n : tvm .sum (
38
+ A [k , n ].astype ("int32" ) * B [k , n ].astype ("int32" ), axis = [k ]), name = 'C' )
39
+ s = tvm .create_schedule (C .op )
40
+ s [C ].vectorize (s [C ].op .axis [0 ])
41
+ f = tvm .build (s , [A , B , C ], target )
42
+
43
+ # Verify we see the correct number of vmlal.s16 instructions
44
+ assembly = f .get_source ('asm' )
45
+ matches = re .findall ("vmlal.s16" , assembly )
46
+ assert (len (matches ) == N // 4 )
47
+ check_correct_assembly (4 )
48
+ check_correct_assembly (8 )
49
+ check_correct_assembly (16 )
50
+
51
+ def check_broadcast_correct_assembly (N ):
52
+ K = tvm .var ("K" )
53
+ A = tvm .placeholder ((K , N ), dtype = "int8" , name = 'A' )
54
+ B = tvm .placeholder ((K ,), dtype = "int8" , name = 'A' )
55
+ k = tvm .reduce_axis ((0 , K ))
56
+ C = tvm .compute ((N , ), lambda n : tvm .sum (
57
+ A [k , n ].astype ("int32" ) * B [k ].astype ("int32" ),
58
+ axis = [k ]), name = 'C' )
59
+ s = tvm .create_schedule (C .op )
60
+ s [C ].vectorize (s [C ].op .axis [0 ])
61
+ f = tvm .build (s , [A , B , C ], target )
62
+
63
+ # Verify we see the correct number of vmlal.s16 instructions
64
+ assembly = f .get_source ('asm' )
65
+ matches = re .findall ("vmlal.s16" , assembly )
66
+ assert len (matches ) == N // 4
67
+ check_broadcast_correct_assembly (8 )
68
+ check_broadcast_correct_assembly (16 )
69
+ check_broadcast_correct_assembly (32 )
70
+ check_broadcast_correct_assembly (64 )
71
+
29
72
if __name__ == "__main__" :
30
73
test_popcount ()
74
+ test_vmlal_s16 ()
0 commit comments