diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 38af8911bc535..0099ccf8bedeb 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -109,6 +109,20 @@ def identity(expr, type_map): register_unary_identity("image.resize2d") +@register_fake_quantization_to_integer("abs") +def abs_(expr, type_map): + """Rewrite an abs op""" + assert len(expr.args) == 1 + arg = expr.args[0] + t = type_map[arg] + + min_value = relay.const(np.iinfo(t.dtype).min, t.dtype) + one = relay.const(1, t.dtype) + out = relay.op.where(relay.op.equal(min_value, arg), arg + one, arg) + out = relay.op.abs(out) + return [out, t] + + @register_fake_quantization_to_integer("nn.adaptive_avg_pool1d") def adaptive_avgpool1d(expr, type_map): """Rewrite an adaptive avgpool op""" diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 602671af41ac9..a004de634d2dc 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -374,6 +374,19 @@ def test_fake_quantize_image_resize_bilinear(): compare_fq_to_int(op, [x_np], allow_rounding_error=True) +def test_fake_quantize_abs(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.abs(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + def test_fake_quantize_expand_dims(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")