4 changes: 4 additions & 0 deletions python/tvm/relay/quantize/quantize.py
@@ -83,6 +83,7 @@ class QConfig(NodeBase):
"do_simulation": False,
"round_for_shift": True,
"debug_enabled_ops": None,
+ "rounding": "UPWARD"
}

# pylint: disable=no-member
@@ -160,6 +161,9 @@ def qconfig(**kwargs):
is None, which means it will try to call all operators' annotate rewrite
function.

+ rounding: "UPWARD" or "TONEAREST"
+     Rounding direction for fixed point multiplications.

Returns
-------
config: QConfig
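For context, a minimal usage sketch of the new option (an illustration under stated assumptions: it presumes the standard relay.quantize entry points, and `mod`/`params` are a Relay module and its parameter dict supplied by the caller):

# Minimal usage sketch for the new `rounding` option. Assumes the
# standard relay.quantize entry points; `mod` and `params` come from
# the caller and are not defined here.
from tvm import relay

def quantize_to_nearest(mod, params):
    # "TONEAREST" selects round-to-nearest in fixed point multiplies;
    # the default "UPWARD" is the cheaper lowering that rounds ties up.
    with relay.quantize.qconfig(rounding="TONEAREST"):
        return relay.quantize.quantize(mod, params)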
3 changes: 2 additions & 1 deletion src/relay/pass/quantize/quantize.cc
@@ -128,7 +128,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
- p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
+ p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
+ p->stream << "rounding==" << op->rounding;
p->stream << ")";
});

2 changes: 2 additions & 0 deletions src/relay/pass/quantize/quantize.h
@@ -77,6 +77,7 @@ class QConfigNode : public Node {
bool do_simulation = false;
bool round_for_shift = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
+ std::string rounding = "UPWARD";

void VisitAttrs(AttrVisitor* v) final {
v->Visit("nbit_input", &nbit_input);
@@ -90,6 +91,7 @@
v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
+ v->Visit("rounding", &rounding);
}

static constexpr const char* _type_key = "relay.quantize.QConfig";
23 changes: 13 additions & 10 deletions src/relay/pass/quantize/realize.cc
@@ -31,6 +31,7 @@
#include <tvm/relay/attrs/annotation.h>
#include "./quantize.h"
#include "../pattern_util.h"
+ #include "../../qnn/util.h"

namespace tvm {
namespace relay {
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {


/* calculate `data * s1 / s2`, use shift if possible */
[Review comment from a Contributor on the new MulAndDiv signature] This seems like a requantize operation? Is it? Maybe we should directly call the QNN Requantize operator and hide the power-of-2 handling inside requantize/fixed point multiplication; special power-of-2 handling would help QNN as well. I am fine with this PR going in first and refactoring later, just putting the point across. Requantize would be a better name for the method.

- inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
+ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
+                       const Array<IndexExpr> &data_shape) {
+   const QConfig& cfg = QConfig::Current();
// here we assume the dtype of data is dtype activation
if (s1 == s2) return data;

@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
- data = Cast(data, Float(32));
- data = Multiply(data, MakeConstantScalar(Float(32), factor));
- return Cast(Round(data), dtype);
+ data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding);
+ return Cast(data, dtype);
}
}
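The rewritten MulAndDiv picks the cheapest lowering for `data * s1 / s2`. A scalar Python model of that case analysis (the power-of-two branch sits in the collapsed part of this hunk, so its exact form is an assumption; `fixed_point_multiply` is the model given further below, next to util.cc):

# Scalar model of MulAndDiv's dispatch. The shift branch is
# reconstructed from the collapsed hunk and is an assumption;
# fixed_point_multiply is modeled below, next to util.cc.
import math

def mul_and_div(x, s1, s2, rounding="UPWARD"):
    if s1 == s2:
        return x                        # scales cancel, nothing to do
    factor = s1 / s2
    exp = math.log2(factor)
    if exp == int(exp) and exp >= 0:
        return x << int(exp)            # exact power of two: a shift suffices
    if factor == int(factor):
        return x * int(factor)          # exact integer: a single multiply
    return fixed_point_multiply(x, factor, rounding)  # general case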

@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
- // float computation
- data = Cast(data, Float(32));
- Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
- Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
- return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+ data = Cast(data, Int(64));
+ data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
+                                ref_call->type_as<TensorTypeNode>()->shape,
+                                cfg->rounding);
+ data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
+ return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
}
}
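Previously this branch simulated the rescale in float32 and returned a float tensor; it now stays in integer arithmetic end to end. A hedged model of the new path (`clip_lo`/`clip_hi` stand in for the clip_min_imm/clip_max_imm bounds computed earlier in the function):

# Hedged model of the integer-only realize path above. clip_lo/clip_hi
# mirror clip_min_imm/clip_max_imm; fixed_point_multiply is modeled
# below, next to util.cc.
def quantize_realize_rescale(x, idom_scale, odom_scale, clip_lo, clip_hi, rounding):
    x = fixed_point_multiply(x, idom_scale / odom_scale, rounding)  # int64 math
    return min(max(x, clip_lo), clip_hi)  # clip into the target dtype's range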

@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
- ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
+ ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
}

*dtype_ptr = dtype;
6 changes: 2 additions & 4 deletions src/relay/qnn/op/requantize.cc
@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);

// Lowering of qnn.requantize op



/*
* \brief Lower requantize to a sequence of ops.
* \param input_tensor The input tensor to requantize op.
@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
// 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) {
- scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
-                                    param->rounding);
+ scaled_int64_t =
+     FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
}

// 3) Add the output zero point.
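Taken together, the numbered steps in this lowering are the standard requantize recipe. A scalar sketch (step numbers mirror the comments visible in the diff; the collapsed step 1 is an assumption based on the usual QNN formulation):

# Scalar sketch of the requantize lowering. Steps follow the comments
# in the diff; the zero-point subtraction in step 1 is assumed from
# the usual QNN recipe, and fixed_point_multiply is modeled below.
def requantize(x, in_scale, in_zp, out_scale, out_zp, rounding="UPWARD"):
    x = x - in_zp                     # 1) remove the input zero point
    if in_scale != out_scale:         # 2) skip the multiply when scales match
        x = fixed_point_multiply(x, in_scale / out_scale, rounding)
    return x + out_zp                 # 3) add the output zero point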
4 changes: 3 additions & 1 deletion src/relay/qnn/util.cc
@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
return std::make_pair(significand, exponent);
}

- Expr FixedPointMuliply(Expr tensor, double multiplier,
+ Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape, const std::string& rounding) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar =
Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+ } else {
+   LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
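Putting the pieces together, here is a scalar Python model of FixedPointMultiply (an illustration of the technique under stated assumptions, not the TVM source): frexp splits the double multiplier into a significand in [0.5, 1) and a power-of-two exponent, the significand becomes a Q31 integer, the product is taken at 64-bit width, a rounding term implements "UPWARD" or "TONEAREST", and one arithmetic right shift scales the result back.

# Scalar model of the fixed point multiply used throughout this PR.
# The Q31 encoding and 64-bit product follow the description in
# util.h; the significand normalization details are assumptions.
import math

def fixed_point_multiply(x, multiplier, rounding="UPWARD"):
    significand, exponent = math.frexp(multiplier)  # multiplier = s * 2**e
    fixed = int(round(significand * (1 << 31)))     # Q31 significand
    if fixed == (1 << 31):                          # rounding pushed s up to 1.0
        fixed //= 2
        exponent += 1
    total_shift = 31 - exponent                     # one right shift at the end
    prod = x * fixed                                # int64 * int32 in the C++ code
    half = 1 << (total_shift - 1)
    if rounding == "UPWARD":
        prod += half                                # ties round toward +infinity
    elif rounding == "TONEAREST":
        prod += half if prod >= 0 else half - 1     # ties round away from zero
    else:
        raise ValueError("Rounding mode %s not supported." % rounding)
    return prod >> total_shift                      # arithmetic shift (floors)

Under this model the two modes differ only on exact ties: for x = -3 and multiplier = 0.5 the exact result is -1.5, and fixed_point_multiply(-3, 0.5, "UPWARD") returns -1 while fixed_point_multiply(-3, 0.5, "TONEAREST") returns -2.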
6 changes: 3 additions & 3 deletions src/relay/qnn/util.h
@@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
* 2) Round the result.
* 3) Right shift the result
*/
- Expr FixedPointMuliply(Expr tensor, double multiplier,
-                        const Array<IndexExpr>& input_shape,
-                        const std::string& rounding);
+ Expr FixedPointMultiply(Expr tensor, double multiplier,
+                         const Array<IndexExpr>& input_shape,
+                         const std::string& rounding);

} // namespace qnn
} // namespace relay