Skip to content

Commit 90951cc

Browse files
Josh Fromm authored and Matthew committed
Fix legalization for non spatial operators. (#6)
* Fix legalization for non spatial operators. * Fix axis checks for end2end functionality.
1 parent 7cf9729 commit 90951cc

File tree

4 files changed

+37
-9
lines changed

4 files changed

+37
-9
lines changed

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import tvm
2222
from tvm import relay
23+
from tvm._ffi.base import TVMError
2324
from .. import op as reg
2425

2526
#################################################
@@ -148,8 +149,21 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
148149
)
149150
# Otherwise it needs to be broadcast.
150151
else:
151-
# Determine output axis of kernel.
152-
output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O")
152+
# Determine output axis of kernel for spatial operations.
153+
if hasattr(attrs, "kernel_layout"):
154+
output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O")
155+
# For dense operations, broadcast to [N, K] layout.
156+
elif isinstance(attrs, relay.op.op_attrs.DenseAttrs):
157+
output_axis = 0
158+
# For matrix multiplication instead expand to [K, N] layout.
159+
elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs):
160+
output_axis = 1
161+
else:
162+
raise TVMError(
163+
"Legalization of %s is not yet supported with per channel parameters"
164+
% str(type(attrs))
165+
)
166+
153167
shift_kernel = relay.nn.bias_add(
154168
relay.cast(kernel, dtype="int16"),
155169
relay.cast(kernel_zero_point, dtype="int16"),

src/relay/qnn/op/dequantize.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,16 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5454
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
5555
int axis = dequantize_attrs->axis;
5656
auto rank = static_cast<int>(data->shape.size());
57-
axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
58-
ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range";
59-
ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range";
57+
58+
// If zero point and scale are scalar then axis doesnt matter.
59+
bool scale_is_scalar = (types[1].as<TensorTypeNode>())->shape.size() == 0;
60+
bool zp_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
61+
62+
if (!(scale_is_scalar && zp_is_scalar)) {
63+
axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
64+
ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range";
65+
ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range";
66+
}
6067

6168
PrimExpr axis_shape;
6269
if (rank > 0) {

src/relay/qnn/op/quantize.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,16 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5252
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
5353
int axis = quantize_attrs->axis;
5454
auto rank = static_cast<int>(data->shape.size());
55-
axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
56-
ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range";
57-
ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range";
55+
56+
// If zero point and scale are scalar then axis doesnt matter.
57+
bool scale_is_scalar = (types[1].as<TensorTypeNode>())->shape.size() == 0;
58+
bool zp_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
59+
60+
if (!(scale_is_scalar && zp_is_scalar)) {
61+
axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
62+
ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range";
63+
ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range";
64+
}
5865

5966
PrimExpr axis_shape;
6067
if (rank > 0) {

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_fake_quantize_dense_per_channel():
130130
x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8")
131131
w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8")
132132

133-
compare_fq_to_int(op, [x_np, w_np])
133+
compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True)
134134

135135

136136
def test_fake_quantize_batch_matmul():

0 commit comments

Comments (0)