Skip to content

Commit a7a0265

Browse files
committed
Add batch_matmul convertion to FQ2I pass
1 parent dc5da05 commit a7a0265

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ def dense(expr, type_map):
139139
return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]
140140

141141

142+
@register_fake_quantization_to_integer("nn.batch_matmul")
143+
def batch_matmul(expr, type_map):
144+
"""Rewrite a batch_matmul op"""
145+
x, y = expr.args
146+
x_t = type_map[x]
147+
y_t = type_map[y]
148+
matmul_scale = fold_constant(x_t.scale * y_t.scale)
149+
matmul_zp = relay.const(0)
150+
out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale)
151+
return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)]
152+
153+
142154
@register_fake_quantization_to_integer("concatenate")
143155
def concat(expr, type_map):
144156
"""Rewrite a concat op"""

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,25 @@ def test_fake_quantize_dense():
7979
compare_fq_to_int(op, [x_np, w_np])
8080

8181

82+
def test_fake_quantize_batch_matmul():
83+
for out_dtype in ["int8", "uint8"]:
84+
x = relay.var("x", shape=[1, 128, 64], dtype="int8")
85+
w = relay.var("w", shape=[1, 256, 64], dtype="int8")
86+
one = relay.const(1.0)
87+
zero = relay.const(0)
88+
89+
op = relay.op.nn.batch_matmul(
90+
relay.qnn.op.dequantize(x, relay.const(2.0), zero),
91+
relay.qnn.op.dequantize(w, relay.const(0.5), zero),
92+
)
93+
op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype)
94+
95+
x_np = np.random.randint(-128, 127, size=[1, 128, 64], dtype="int8")
96+
w_np = np.random.randint(-128, 127, size=[1, 256, 64], dtype="int8")
97+
98+
compare_fq_to_int(op, [x_np, w_np])
99+
100+
82101
def test_fake_transpose_quantize_conv():
83102
x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8")
84103
w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")

0 commit comments

Comments
 (0)