Skip to content

Commit a7e35fc

Browse files
ajtulloch authored and tqchen committed
Fix vmlal.s16 code generation for int8 x int8 -> int32 (#2748)
1 parent 2239508 commit a7e35fc

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

src/pass/lower_intrin.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,23 @@ class IntrinInjecter : public IRMutator {
5050
// on ARM.
5151
if (const Broadcast* bcast = e.as<Broadcast>()) {
5252
if (const Cast* cast = bcast->value.as<Cast>()) {
53-
if (cast->type.bits() == cast->value.type().bits() * 2) {
53+
auto should_swap = [&]() {
54+
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
55+
if (cast->type.bits() == cast->value.type().bits() * 2) {
56+
return true;
57+
}
58+
// Check both operands are integer-like.
59+
if (!cast->type.is_uint() && !cast->type.is_int()) {
60+
return false;
61+
}
62+
if (!cast->value.type().is_uint() && !cast->value.type().is_int()) {
63+
return false;
64+
}
65+
// If both are integer-like, swap if we have a widening cast.
66+
return cast->type.bits() > cast->value.type().bits();
67+
};
68+
69+
if (should_swap()) {
5470
Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
5571
return Cast::make(bcast->type, new_bcast);
5672
}

tests/python/unittest/test_codegen_arm.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,49 @@ def check_correct_assembly(type, elements, counts):
2626
check_correct_assembly('uint32', 2, 2)
2727
check_correct_assembly('uint64', 2, 3)
2828

29+
def test_vmlal_s16():
    """Check the number of ``vmlal.s16`` instructions emitted for int8 x int8 -> int32.

    For several vector widths N, builds a vectorized multiply-accumulate
    reduction over a symbolic axis K for the ARM NEON target below and asserts
    that the generated assembly contains exactly ``N // 4`` ``vmlal.s16``
    instructions. Two cases are covered: both operands indexed by the
    vectorized axis, and one operand that does not depend on it.
    """
    target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

    def count_vmlal_s16(A, B, C):
        """Vectorize C over its first axis, build for `target`, count vmlal.s16."""
        s = tvm.create_schedule(C.op)
        s[C].vectorize(s[C].op.axis[0])
        f = tvm.build(s, [A, B, C], target)
        assembly = f.get_source('asm')
        # Raw string with an escaped dot: match the literal mnemonic only,
        # not "vmlal" followed by an arbitrary character and "s16".
        return len(re.findall(r"vmlal\.s16", assembly))

    def check_correct_assembly(N):
        K = tvm.var("K")
        A = tvm.placeholder((K, N), dtype="int8", name='A')
        # Named 'B' (was 'A') so the two inputs do not share a name.
        B = tvm.placeholder((K, N), dtype="int8", name='B')
        k = tvm.reduce_axis((0, K))
        C = tvm.compute((N, ), lambda n: tvm.sum(
            A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')

        # Verify we see the correct number of vmlal.s16 instructions
        assert count_vmlal_s16(A, B, C) == N // 4

    check_correct_assembly(4)
    check_correct_assembly(8)
    check_correct_assembly(16)

    def check_broadcast_correct_assembly(N):
        K = tvm.var("K")
        A = tvm.placeholder((K, N), dtype="int8", name='A')
        # B is indexed only by k, so it does not vary along the vectorized axis.
        B = tvm.placeholder((K,), dtype="int8", name='B')
        k = tvm.reduce_axis((0, K))
        C = tvm.compute((N, ), lambda n: tvm.sum(
            A[k, n].astype("int32") * B[k].astype("int32"),
            axis=[k]), name='C')

        # Verify we see the correct number of vmlal.s16 instructions
        assert count_vmlal_s16(A, B, C) == N // 4

    check_broadcast_correct_assembly(8)
    check_broadcast_correct_assembly(16)
    check_broadcast_correct_assembly(32)
    check_broadcast_correct_assembly(64)
71+
2972
if __name__ == "__main__":
    # When executed as a script, run each test in order.
    for _test_case in (test_popcount, test_vmlal_s16):
        _test_case()

0 commit comments

Comments
 (0)