Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions src/operator/quantization/quantize_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ Graph QuantizeGraph(Graph &&src) {
static const auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
static const auto& avoid_quantize_input_map =
Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
static const auto& flist_inputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListInputNames");
const auto offline_params = src.GetAttr<std::unordered_set<std::string>>("offline_params");
const auto quantized_dtype = src.GetAttr<std::string>("quantized_dtype");
const auto quantize_granularity = src.GetAttr<std::string>("quantize_granularity");
Expand Down Expand Up @@ -346,7 +347,13 @@ Graph QuantizeGraph(Graph &&src) {
std::string name = GetOutputName(e.node.get(), e.index);
suffix = "_" + name;
} else if (!offline_params.count(new_name)) {
new_name = node->attrs.name + "_" + e.node->attrs.name;
std::string input_name;
if (flist_inputs.count(node->op())) {
input_name = flist_inputs[node->op()](node->attrs)[i];
new_name = node->attrs.name + "_" + input_name;
} else {
new_name = node->attrs.name + "_" + e.node->attrs.name;
}
}

ObjectPtr quantize_node = InsertNode("_contrib_quantize_v2",
Expand Down Expand Up @@ -504,20 +511,33 @@ Graph QuantizeGraph(Graph &&src) {
static const auto& need_calib_output_map =
Op::GetAttr<mxnet::FNeedCalibrateOutput>("FNeedCalibrateOutput");

std::stack<std::string> calib_variables;
std::unordered_set<nnvm::ObjectPtr> calib_variables;
std::vector<std::string> calib_nodes;
DFSVisit(ret.outputs, [&](const ObjectPtr& node) {
if (node->op() && !calib_variables.empty()) {
if (reverse_mirror_map.count(node)) {
const std::string& var_name = calib_variables.top();
const auto& fp32_in_node = reverse_mirror_map[node];
for (const auto &input_node : fp32_in_node->inputs) {
if (var_name == input_node.node->attrs.name) {
calib_nodes.push_back(fp32_in_node->attrs.name + "_" + var_name);
calib_variables.pop();
break;
// find nodes that take a variable node as input
// and push the properly formed input name onto calib_nodes
for (int i = 0; i < node->inputs.size(); i++) {
const auto &input_node = node->inputs[i];
if (calib_variables.find(input_node.node) != std::end(calib_variables)) {
auto fp32_node = std::find_if(std::begin(quantized_node_map),
std::end(quantized_node_map),
[&](const std::pair<ObjectPtr, ObjectPtr> &pair) {
return pair.second == node;
});
if (fp32_node != std::end(quantized_node_map)) {
const auto& fp32_in_node = fp32_node->first;
std::string node_input_name;
if (flist_inputs.count(fp32_in_node->op())) {
std::string op_input_name = flist_inputs[fp32_in_node->op()](fp32_in_node->attrs)[i];
node_input_name = fp32_in_node->attrs.name + "_" + op_input_name;
} else {
node_input_name = fp32_in_node->attrs.name + "_" + input_node.node->attrs.name;
}
calib_nodes.push_back(node_input_name);
calib_variables.erase(input_node.node);
}
}
}
}
if (need_calib_input_map.count(node->op())) {
Expand All @@ -530,10 +550,13 @@ Graph QuantizeGraph(Graph &&src) {
} else {
const auto& e = node->inputs[idx];
if (e.node->is_variable()) {
// monitor callback join operator name and variable name as observable node,
// utilize fact that we're using DFS and put variable name on stack to
// find operator node name for this variable node
calib_variables.emplace(e.node->attrs.name);
// monitor callback join operator name and variable name as observable node name,
// instead of using variable output we can use op node input
//
// data_output/fc_input
// e.g. data (var.) ----------------------> FC (op)
// remember current node and compare with inputs of next nodes
calib_variables.insert(node);
} else {
if (reverse_mirror_map.count(e.node)) {
const auto& fp32_in_node = reverse_mirror_map.at(e.node);
Expand Down
9 changes: 0 additions & 9 deletions src/operator/quantization/quantized_elemwise_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,6 @@ namespace op {

DMLC_REGISTER_PARAMETER(QuantizeElemwiseMulParam);

static std::vector<std::string> QuantizedElemwiseMulOutputNames(const NodeAttrs &attrs) {
const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
if (params.enable_float_output)
return std::vector<std::string>{"output"};
else
return std::vector<std::string>{"output", "min_output", "max_output"};
}

inline bool QuantizedElemwiseMulOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
Expand Down Expand Up @@ -228,7 +220,6 @@ NNVM_REGISTER_OP(_contrib_quantized_elemwise_mul)
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs", "lhs_min", "lhs_max", "rhs_min", "rhs_max"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", QuantizedElemwiseMulOutputNames)
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedElemwiseMulOpShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedElemwiseMulOpType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedElemwiseMulOpStorageType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
#if MXNET_USE_ONEDNN == 1

#include <memory>
#include <string>
#include <vector>
#include "../../tensor/elemwise_binary_op-inl.h"
Expand All @@ -40,7 +41,7 @@ namespace op {

#define QUANTIZED_ElemwiseMul_NAME "_contrib_quantized_elemwise_mul"

class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
class ElemwiseMulPostQuantizeSelector : public SubgraphSelectorV2 {
public:
/*! \brief pattern match status */
enum SelectStatus {
Expand All @@ -54,16 +55,17 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
bool disable_all;
bool disable_float_output;
SelectStatus status;
std::vector<const nnvm::Node *> matched_list;
std::vector<const BiDirectedNode *> matched_list;

public:
explicit ElemwiseMulPostQuantizeSelector(const bool dis_all,
const bool dis_float_output)
: disable_all(dis_all),
disable_float_output(dis_float_output) {}

bool Select(const nnvm::Node &n) override {
if ((!disable_all) && n.op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
bool Select(const BiDirectedNode &n) override {
const auto rawnode = n.node;
if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
status = disable_all ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand All @@ -72,12 +74,14 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
return false;
}

bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
return false;
}

bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
if (status == kFail || status == kSuccess || new_node.is_variable())
bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
const auto raw_node = n.node;
const auto raw_new_node = new_node.node;
if (status == kFail || status == kSuccess || raw_new_node->is_variable())
return false;
// If n isn't the last matched node, then we encountered an internal
// branch; we should pop out the node behind n and stop fusion.
Expand All @@ -95,8 +99,8 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {

switch (status) {
case kStart:
if (new_node.op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
matched_list.push_back(&new_node);
Expand All @@ -105,7 +109,20 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
}
}
case kRequantize:
if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) {
CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
if (n.outputs.size() > 1) {
// check whether requantize has outputs other than dequantize;
// if it does, we can't fuse dequantize into elemwise_mul
for (auto kv : n.outputs) {
const auto& node = kv.first;
if (node->op() != Op::Get("_contrib_dequantize")) {
status = kSuccess;
return false;
}
}
}

matched_list.push_back(&new_node);
status = kSuccess;
return true;
Expand All @@ -116,14 +133,14 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
}
}

std::vector<nnvm::Node *> Filter(
const std::vector<nnvm::Node *> &candidates) override {
std::vector<BiDirectedNode *> Filter(
const std::vector<BiDirectedNode *>& candidates) override {
if ((status != kSuccess) || (matched_list.size() <= 1)) {
return std::vector<nnvm::Node *>(0);
return std::vector<BiDirectedNode *>(0);
} else {
std::vector<nnvm::Node *> ret;
std::vector<BiDirectedNode *> ret;
for (auto i : matched_list) {
auto non_const_i = const_cast<nnvm::Node *>(i);
auto non_const_i = const_cast<BiDirectedNode *>(i);
if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
ret.push_back(non_const_i);
Expand Down Expand Up @@ -194,7 +211,7 @@ class ElemwiseMulPostQuantizeProperty : public SubgraphProperty {
return em_node;
}

SubgraphSelectorPtr CreateSubgraphSelector() const override {
SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector =
std::make_shared<ElemwiseMulPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
Expand Down
48 changes: 32 additions & 16 deletions src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_POST_QUANTIZE_PROPERTY_H_
#if MXNET_USE_ONEDNN == 1

#include <memory>
#include <string>
#include <vector>
#include "../../nn/fully_connected-inl.h"
Expand All @@ -40,7 +41,7 @@ namespace op {

#define QUANTIZED_FC_NAME "_sg_mkldnn_fully_connected"

class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelectorV2 {
public:
/*! \brief pattern match status */
enum SelectStatus {
Expand All @@ -54,16 +55,17 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
bool disable_all;
bool disable_float_output;
SelectStatus status;
std::vector<const nnvm::Node *> matched_list;
std::vector<const BiDirectedNode *> matched_list;

public:
explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all,
const bool dis_float_output)
: disable_all(dis_all),
disable_float_output(dis_float_output) {}

bool Select(const nnvm::Node &n) override {
if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) {
bool Select(const BiDirectedNode &n) override {
const auto rawnode = n.node;
if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_FC_NAME)) {
status = disable_all ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand All @@ -72,12 +74,14 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
return false;
}

bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
return false;
}

bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
if (status == kFail || status == kSuccess || new_node.is_variable())
bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
const auto raw_node = n.node;
const auto raw_new_node = new_node.node;
if (status == kFail || status == kSuccess || raw_new_node->is_variable())
return false;
// If n isn't the last matched node, then we encountered an internal
// branch; we should pop out the node behind n and stop fusion.
Expand All @@ -95,8 +99,8 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {

switch (status) {
case kStart:
if (new_node.op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
matched_list.push_back(&new_node);
Expand All @@ -105,7 +109,19 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
}
}
case kRequantize:
if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) {
CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
if (n.outputs.size() > 1) {
// check whether requantize has outputs other than dequantize;
// if it does, we can't fuse dequantize into FC
for (auto kv : n.outputs) {
const auto& node = kv.first;
if (node->op() != Op::Get("_contrib_dequantize")) {
status = kSuccess;
return false;
}
}
}
matched_list.push_back(&new_node);
status = kSuccess;
return true;
Expand All @@ -116,14 +132,14 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
}
}

std::vector<nnvm::Node *> Filter(
const std::vector<nnvm::Node *> &candidates) override {
std::vector<BiDirectedNode *> Filter(
const std::vector<BiDirectedNode *>& candidates) override {
if ((status != kSuccess) || (matched_list.size() <= 1)) {
return std::vector<nnvm::Node *>(0);
return std::vector<BiDirectedNode *>(0);
} else {
std::vector<nnvm::Node *> ret;
std::vector<BiDirectedNode *> ret;
for (auto i : matched_list) {
auto non_const_i = const_cast<nnvm::Node *>(i);
auto non_const_i = const_cast<BiDirectedNode *>(i);
if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
ret.push_back(non_const_i);
Expand Down Expand Up @@ -194,7 +210,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
return fc_node;
}

SubgraphSelectorPtr CreateSubgraphSelector() const override {
SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector =
std::make_shared<SgMKLDNNFCPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/subgraph/mkldnn/mkldnn_fc_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace op {

class SgMKLDNNFCSelector : public SubgraphSelector {
public:
/*! \brief pattern match status */
/* pattern match status */
enum SelectStatus {
kFail = 0,
kStart,
Expand Down
24 changes: 24 additions & 0 deletions tests/python/mkl/subgraphs/test_fc_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,27 @@ def hybrid_forward(self, F, x, weight, bias):
out_quantized = qnet(data_nd)
assert_almost_equal_with_err(out.asnumpy(), out_quantized.asnumpy(),
rtol=1e-2, atol=1e-2, etol=0.01)


@pytest.mark.parametrize('data_shape', DATA_SHAPE)
def test_fc_int8_and_fp32_outputs(data_shape):
    # Topology under test: the first FC output feeds both a quantizable
    # and a non-quantizable consumer.
    #
    #                /---> Quantizable op
    #  Input ---> FC -|
    #                \---> Non quantizable op

    class MultiOutputFC(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(MultiOutputFC, self).__init__(**kwargs)
            self.dense0 = nn.Dense(64)
            self.dense1 = nn.Dense(64)

        def hybrid_forward(self, F, x):
            hidden = self.dense0(x)
            quantizable = self.dense1(hidden)   # quantizable
            non_quantizable = F.softmax(hidden)  # non quantizable
            return quantizable + non_quantizable

    attrs = {'fc': {}}
    check_fusion(MultiOutputFC(), data_shape, attrs, check_quantization=True)