Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fake_quantization op conversion in onnx #4512

Merged
merged 41 commits into from
Mar 28, 2021
Merged
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
76765bd
add min_max_observer and moving_average_min_max_observer conversion i…
mosout Mar 20, 2021
bb69ec9
fix op's name
mosout Mar 21, 2021
a1cbfaf
fix moving_average_min_max_observer
mosout Mar 21, 2021
67415dd
update quantization ops conversion
mosout Mar 22, 2021
a643999
update quantization ops conversion and tests
mosout Mar 23, 2021
bfa1abe
update ops version
mosout Mar 23, 2021
39389d4
delete auto-imported package
mosout Mar 23, 2021
2efd5d1
update min_max_observer op
mosout Mar 23, 2021
1dfcc4f
Merge branch 'master' into quantize
oneflow-ci-bot Mar 23, 2021
9f2f8f8
format
mosout Mar 23, 2021
17e6848
Merge remote-tracking branch 'origin/quantize' into quantize
mosout Mar 23, 2021
516d82d
Merge branch 'master' into quantize
oneflow-ci-bot Mar 23, 2021
7ca1c77
Merge branch 'master' into quantize
oneflow-ci-bot Mar 23, 2021
dec9f35
Merge branch 'master' into quantize
oneflow-ci-bot Mar 23, 2021
3f02448
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
338e401
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
6f5e08e
fix test_quantization_aware_training
mosout Mar 24, 2021
32e3112
Merge remote-tracking branch 'origin/quantize' into quantize
mosout Mar 24, 2021
5347a90
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
e67e140
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
1a7440a
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
4ebc15b
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
c19076f
update test_quantize_op
mosout Mar 24, 2021
91f9b56
Merge remote-tracking branch 'origin/quantize' into quantize
mosout Mar 24, 2021
9ecb743
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
c269cad
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
9405ff9
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
fd0cfd0
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
817cd92
Merge branch 'master' into quantize
oneflow-ci-bot Mar 24, 2021
6a18c05
add fake_quantization conversion in onnx
mosout Mar 25, 2021
1d4c4db
Merge remote-tracking branch 'origin/quantize' into quantize
mosout Mar 25, 2021
6a6ba42
Merge remote-tracking branch 'upstream/master' into quantize
mosout Mar 25, 2021
7338873
format quantize.py
mosout Mar 25, 2021
50ddd0f
fix fake_quantization
mosout Mar 27, 2021
9560830
update fake_quantization conversion
mosout Mar 27, 2021
42783d0
update fake_quantization conversion and its test
mosout Mar 27, 2021
cfd86e6
update quantization_aware_training
mosout Mar 27, 2021
80222b6
Merge branch 'master' into quantize
oneflow-ci-bot Mar 27, 2021
74ee5ae
update quantization_aware_training
mosout Mar 27, 2021
da8f65e
Merge remote-tracking branch 'upstream/master' into quantize
mosout Mar 27, 2021
018dd12
Merge remote-tracking branch 'origin/quantize' into quantize
mosout Mar 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions oneflow/core/job_rewriter/quantization_aware_training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ std::string OpTypeName4OpNode(const OpNode* node) {

using OpConfMap = HashMap<std::string, OperatorConf>;

template<DataType VALUE_TYPE = DataType::kFloat>
template<DataType data_type = DataType::kFloat>
OperatorConf Get1DZeroVariableOpConf(std::string name, const int64_t scope_symbol_id,
const int64_t length, OpConfMap* inserted_ops) {
OperatorConf variable_op_conf{};
Expand All @@ -133,7 +133,7 @@ OperatorConf Get1DZeroVariableOpConf(std::string name, const int64_t scope_symbo
VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
variable_conf->set_out("out");
*variable_conf->mutable_shape()->mutable_dim()->Add() = length;
variable_conf->set_data_type(VALUE_TYPE);
variable_conf->set_data_type(data_type);
variable_conf->mutable_initializer()->mutable_constant_conf()->set_value(0);
(*inserted_ops)[name] = variable_op_conf;
return variable_op_conf;
Expand Down Expand Up @@ -255,22 +255,19 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s
Get1DZeroVariableOpConf(moving_max_name, scope_symbol_id, 1, inserted_ops);
const auto moving_min_var =
Get1DZeroVariableOpConf(moving_min_name, scope_symbol_id, 1, inserted_ops);
std::string train_step_value = train_step_lbn;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里换一个清晰一些的名字,例如 observer_current_train_step

if (!GlobalJobDesc().IsTrain()) {
const std::string train_step_name = name + TRAIN_STEP_SUFFIX;
const auto train_step_var = Get1DZeroVariableOpConf<DataType::kInt64>(
train_step_name, scope_symbol_id, 1, inserted_ops);
train_step_value =
GenLogicalBlobName(train_step_var.name(), train_step_var.variable_conf().out());
}
const auto op_wrapper =
user_op::UserOpConfWrapperBuilder(name)
.Op("moving_average_min_max_observer")
.Input("in", input)
.Input("current_train_step",
[&] {
const std::string train_step_name = name + TRAIN_STEP_SUFFIX;
const auto train_step_var = Get1DZeroVariableOpConf<DataType::kInt64>(
train_step_name, scope_symbol_id, 1, inserted_ops);
if (GlobalJobDesc().IsTrain()) {
return train_step_lbn;
} else {
return GenLogicalBlobName(train_step_var.name(),
train_step_var.variable_conf().out());
}
}())
.Input("current_train_step", train_step_value)
.Input("moving_max",
GenLogicalBlobName(moving_max_var.name(), moving_max_var.variable_conf().out()))
.Input("moving_min",
Expand Down