Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Restore Quantization API to MXNet #19587

Merged
merged 40 commits into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
840bc2d
Restore quantization files
bgawrych Sep 16, 2020
2f2adcb
Adapt quantization.py - Add/Remove modules
bgawrych Sep 16, 2020
a2fd342
Adapt part of quantization tests to new API
bgawrych Sep 16, 2020
6d6ff15
fuse fc+tanh
sfraczek Sep 18, 2020
889a0dd
Replace Module API with SymbolBlock in quantize_model
bgawrych Sep 21, 2020
2937d5b
enabled test_quantization_mkldnn.py
sfraczek Sep 29, 2020
2ed842c
Revert "fuse fc+tanh"
bgawrych Sep 30, 2020
7694290
Enable tests from test_subgraph.py
bgawrych Oct 1, 2020
9364b30
Enable test_mobilenetv2_struct
bgawrych Oct 2, 2020
2da0f96
Refactor test_subgraph.py
bgawrych Oct 5, 2020
b6c3740
Reorder of conditions
sfraczek Oct 8, 2020
d8af341
Utilize optimize_for in quantization flow
bgawrych Oct 7, 2020
2e8aab7
remove duplicate imports
bgawrych Oct 8, 2020
8bdfb0b
Add variable monitor callback
bgawrych Oct 19, 2020
03a568f
fix sanity
bgawrych Oct 19, 2020
9c7483a
wip
sfraczek Oct 29, 2020
cf5376c
Rebase to master - remove with_seed
bgawrych Nov 3, 2020
07b3942
Add numpy support for quantization
bgawrych Nov 3, 2020
a8f80e9
enabled examples/quantization/imagenet_gen_qsym_mkldnn.py
sfraczek Nov 5, 2020
e7428b2
Add test to check different way of data generation for hybridize
bgawrych Nov 5, 2020
fb7bf99
Copy original network
bgawrych Nov 6, 2020
f095f25
Change num_calib_examples to num_calib_batches
bgawrych Nov 6, 2020
4e3591e
enabling imagenet_inference.py
sfraczek Nov 9, 2020
c898b14
Add base class for collectors and feed custom with calib_layers
bgawrych Nov 9, 2020
bda6463
Some doc fixes after discussion
grygielski Nov 12, 2020
0375801
anko review - change all quantize_net_v2 to quantize_net
bgawrych Nov 13, 2020
3e1ee58
Make -s argument required
bgawrych Nov 13, 2020
2db7c6d
review fixes by mozga and anko
sfraczek Nov 18, 2020
b080a24
Fix bugs
bgawrych Nov 18, 2020
0773950
Fix channel-wise quantization
bgawrych Nov 23, 2020
297eaac
Fix documentation formatting
bgawrych Nov 24, 2020
3b6cd2b
mozga: fix review
bgawrych Nov 23, 2020
a62047b
Fix lint
bgawrych Nov 26, 2020
3f661bd
Refactor calibration for variables
bgawrych Nov 26, 2020
674f48f
fix sanity
bgawrych Nov 27, 2020
cbb8b2e
fix clang tidy
bgawrych Nov 27, 2020
20fce07
Merge branch 'master' into quantization_20
bgawrych Nov 30, 2020
a9daf5d
ciyong review fixes
bgawrych Dec 4, 2020
1e34d48
Add verified models
bgawrych Dec 9, 2020
f24e2bc
Fix review: Tao and Xinyu
bgawrych Dec 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor calibration for variables
  • Loading branch information
bgawrych committed Nov 27, 2020
commit 3f661bd492137923a7eb72d9c7765329ee7e274d
7 changes: 0 additions & 7 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,6 @@ void ExecuteMonInputCallback(
if (state_arrays[idx.entry_id(input)]->is_none()) {
continue;
}
if (input.node->is_variable()) {
// Monitor variable
NDArray *var_cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
std::string var_name = input.node->attrs.name;
monitor_callback(var_name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void*>(var_cpy));
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
std::string name = inode.source->attrs.name + "_" + input_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
Expand Down
31 changes: 29 additions & 2 deletions src/operator/quantization/quantize_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -339,13 +340,17 @@ Graph QuantizeGraph(Graph &&src) {
// Or the output name is not ending with 'output', just put the output name here
// to better align with calibration phase. No need to change name to weights/bias.
std::string suffix = "";
std::string new_name = e.node->attrs.name;

if (mirror_node->op() != nullptr) {
std::string name = GetOutputName(e.node.get(), e.index);
suffix = "_" + name;
} else if(!offline_params.count(new_name)){
new_name = node->attrs.name + "_" + new_name;
}

ObjectPtr quantize_node = InsertNode("_contrib_quantize_v2",
e.node->attrs.name + suffix + "_quantize", new_node, mirror_entry);
new_name + suffix + "_quantize", new_node, mirror_entry);
quantize_node->attrs.dict["out_type"] = quantized_dtype;
quantize_node->op()->attr_parser(&(quantize_node->attrs));
mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version};
Expand Down Expand Up @@ -498,8 +503,23 @@ Graph QuantizeGraph(Graph &&src) {
Op::GetAttr<mxnet::FNeedCalibrateInput>("FNeedCalibrateInput");
static const auto& need_calib_output_map =
Op::GetAttr<mxnet::FNeedCalibrateOutput>("FNeedCalibrateOutput");

std::stack<std::string> calib_variables;
std::vector<std::string> calib_nodes;
DFSVisit(ret.outputs, [&](const ObjectPtr& node) {
if (node->op() && !calib_variables.empty()) {
if (reverse_mirror_map.count(node)) {
const std::string& var_name = calib_variables.top();
const auto& fp32_in_node = reverse_mirror_map[node];
for (const auto &input_node : fp32_in_node->inputs) {
if (var_name == input_node.node->attrs.name) {
calib_nodes.push_back(fp32_in_node->attrs.name + "_" + var_name);
calib_variables.pop();
break;
}
}
}
}
if (need_calib_input_map.count(node->op())) {
const auto calib_idx = need_calib_input_map[node->op()](node->attrs);
for (const auto &idx : calib_idx) {
Expand All @@ -510,7 +530,10 @@ Graph QuantizeGraph(Graph &&src) {
} else {
const auto& e = node->inputs[idx];
if (e.node->is_variable()) {
calib_nodes.push_back(e.node->attrs.name);
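// The monitor callback names an observed input by joining the operator name and
// the input name. Since the graph is traversed with DFS, push the variable name
// onto a stack so that the operator node consuming this variable can be found
// later in the traversal and used to build the matching calibration name.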
calib_variables.emplace(e.node->attrs.name);
} else {
if (reverse_mirror_map.count(e.node)) {
const auto& fp32_in_node = reverse_mirror_map.at(e.node);
Expand Down Expand Up @@ -548,6 +571,10 @@ static inline void SetCalibTableForEntry(

if (!e.node->is_variable()) {
full_node_name += "_" + out_name;
} else {
const std::string suffix = "_quantize";
full_node_name = node->attrs.name;
full_node_name = std::string(full_node_name.begin(), full_node_name.end() - suffix.size());
}

const std::string prefix = "quantized_";
Expand Down