Skip to content

Commit ff9c480

Browse files
authored
making quantization tweaks (#6731)
1 parent 89ce1ed commit ff9c480

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

python/tvm/relay/quantize/_annotate.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,28 @@ def conv2d_rewrite(ref_call, new_args, ctx):
175175
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
176176

177177

178+
@register_annotate_function("nn.conv1d")
def conv1d_rewrite(ref_call, new_args, ctx):
    """Annotation rewrite for ``nn.conv1d``.

    The first argument (data) is wrapped in a simulated quantize of kind
    INPUT when it carries no annotation yet or is annotated as ACTIVATION.
    The second argument (weight) must not carry an annotation and is always
    wrapped in a simulated quantize of kind WEIGHT. The forwarded call is
    returned annotated as ACTIVATION.

    Returns None when the quantize context asks to skip ``ref_call``.
    ``ctx`` is not used by this function.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    data, data_kind = _get_expr_kind(new_args[0])
    weight, weight_kind = _get_expr_kind(new_args[1])

    # Data that is un-annotated or still in the activation field has to be
    # brought into the input field before the convolution consumes it.
    needs_input_quantize = data_kind is None or data_kind == QAnnotateKind.ACTIVATION
    if needs_input_quantize:
        data = attach_simulated_quantize(data, QAnnotateKind.INPUT)

    # The weight operand is expected to arrive without any annotation.
    assert weight_kind is None
    weight = attach_simulated_quantize(weight, QAnnotateKind.WEIGHT)

    return QAnnotateExpr(_forward_op(ref_call, [data, weight]), QAnnotateKind.ACTIVATION)
198+
199+
178200
@register_annotate_function("nn.dense")
179201
def dense_rewrite(ref_call, new_args, ctx):
180202
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
@@ -289,6 +311,8 @@ def identity_rewrite(ref_call, new_args, ctx):
289311
register_annotate_function("nn.relu", identity_rewrite)
290312
register_annotate_function("strided_slice", identity_rewrite)
291313
register_annotate_function("nn.avg_pool2d", identity_rewrite)
314+
# Annotate nn.batch_flatten with the shared identity_rewrite handler.
register_annotate_function("nn.batch_flatten", identity_rewrite)
315+
# Annotate transpose with the shared identity_rewrite handler.
register_annotate_function("transpose", identity_rewrite)
292316
register_annotate_function("annotation.stop_fusion", identity_rewrite)
293317

294318

@@ -311,6 +335,25 @@ def pool2d_rewrite(ref_call, new_args, ctx):
311335
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
312336

313337

338+
def pool1d_rewrite(ref_call, new_args, ctx):
    """Annotation rewrite for 1-D max pooling.

    Returns None when the quantize context asks to skip ``ref_call`` or when
    the single input carries no annotation. An input annotated as ACTIVATION
    is first wrapped in a simulated quantize of kind INPUT. The forwarded
    call is returned annotated as INPUT.

    ``ctx`` is not used by this function.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    data, data_kind = _get_expr_kind(new_args[0])
    if data_kind is None:
        # Nothing upstream was annotated, so leave this call untouched.
        return None

    if data_kind == QAnnotateKind.ACTIVATION:
        data = attach_simulated_quantize(data, QAnnotateKind.INPUT)

    pooled = _forward_op(ref_call, [data])
    return QAnnotateExpr(pooled, QAnnotateKind.INPUT)
352+
353+
354+
# Route nn.max_pool1d annotation through pool1d_rewrite.
register_annotate_function("nn.max_pool1d", pool1d_rewrite)
355+
356+
314357
@register_annotate_function("annotation.cast_hint")
315358
def cast_hint_rewrite(ref_call, new_args, ctx):
316359
"""Rewrite function to force cast"""

src/relay/quantize/realize.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,37 @@ Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const Obje
234234

235235
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
236236

237+
// Realize rewrite for nn.conv1d.
//
// Returns a null Expr when neither argument is a TempExprNode. Otherwise both
// arguments must be QRealizeIntExprNode. The data operand is cast to
// cfg->dtype_input when its dtype differs, and the weight operand is always
// cast to cfg->dtype_weight. The call is rebuilt with a copy of the original
// Conv1DAttrs whose out_dtype is set to cfg->dtype_activation, and the result
// carries the constant-folded product of the two operands' dom_scale values.
//
// ctx is not used by this function.
Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // Use the ICHECK family, as the neighbouring realize functions in this file do.
  ICHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  ICHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  ICHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  const auto* ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
  // Guard the dereference below: as<>() yields nullptr on a type mismatch.
  ICHECK(ref_attrs);
  auto attrs = make_object<Conv1DAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
265+
266+
// Attach Conv1dRealize to nn.conv1d under the FQRealizeRewrite attribute.
RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);
267+
237268
Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
238269
const QConfig& cfg = QConfig::Current();
239270
ICHECK_EQ(new_args.size(), 2);
@@ -449,6 +480,8 @@ RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite",
449480
RELAY_REGISTER_OP("nn.batch_flatten")
450481
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
451482

483+
// Realize transpose with the shared IdentityRealize rewrite.
RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
484+
452485
RELAY_REGISTER_OP("annotation.stop_fusion")
453486
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
454487

@@ -469,6 +502,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
469502
RELAY_REGISTER_OP("nn.max_pool2d")
470503
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
471504

505+
// Realize nn.max_pool1d with CastDtypeInputRealize, the same rewrite that
// nn.max_pool2d is registered with.
RELAY_REGISTER_OP("nn.max_pool1d")
506+
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
507+
472508
Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
473509
const QConfig& cfg = QConfig::Current();
474510
ICHECK_EQ(new_args.size(), 1);

0 commit comments

Comments
 (0)