fix return type error on gcc 4.8.5 #5660

Merged (10 commits, Jul 30, 2021)
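For context, a hedged reduction (not from the PR) of the class of error this fixes: GCC 4.8.5 builds that treat -Wreturn-type as an error reject any non-void function where control can fall off the end. OneFlow's UNIMPLEMENTED() aborts at runtime, but the old compiler cannot see it as noreturn, so a branch ending in it still counts as falling through. The changes below either add an unreachable return after the macro (py_remote_blob.cpp, scalar_binary_op.cpp) or convert the function to Maybe<T> so every branch returns. The names in this sketch are hypothetical:

// Hypothetical reduction; report_unimplemented() stands in for
// UNIMPLEMENTED(), which aborts at runtime but is not declared noreturn.
#include <cstdio>
#include <cstdlib>

void report_unimplemented() {
  std::fprintf(stderr, "UNIMPLEMENTED\n");
  std::exit(1);
}

int split_axis(bool valid, int axis) {
  if (valid) { return axis; }
  report_unimplemented();
  return 0;  // unreachable in practice; silences GCC 4.8.5's
             // "control reaches end of non-void function" error
}

int main() { return split_axis(true, 2) == 2 ? 0 : 1; }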
1 change: 1 addition & 0 deletions oneflow/core/framework/py_remote_blob.cpp
@@ -197,6 +197,7 @@ int64_t EagerBlobTrait::split_axis() const {
     return INVALID_SPLIT_AXIS;
   } else {
     UNIMPLEMENTED();
+    return 0;
   }
 }

1 change: 1 addition & 0 deletions oneflow/core/job_rewriter/pass_util.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 namespace oneflow {
 #define INSERT_CHECK(expr) CHECK(expr.second)
+#define INSERT_CHECK_OR_RETURN(expr) CHECK_OR_RETURN(expr.second)
 
 template<typename MapT, typename KeyT>
 bool IsKeyFound(const MapT& m, const KeyT& k) {
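A hedged sketch of why the new macro is needed, using simplified stand-ins rather than OneFlow's real definitions: CHECK aborts the whole process on failure, which does not fit functions that report errors through Maybe<void>; CHECK_OR_RETURN instead propagates the failure to the caller, so INSERT_CHECK_OR_RETURN lets a pass flag a duplicate insertion without crashing.

// Illustrative only: minimal stand-ins for Maybe<void>, CHECK, and
// CHECK_OR_RETURN, contrasting the two insertion-check macros.
#include <cstdlib>
#include <set>

struct MaybeVoid { bool ok; };
#define CHECK(cond) \
  do { if (!(cond)) std::abort(); } while (0)
#define CHECK_OR_RETURN(cond) \
  if (!(cond)) return MaybeVoid{false}
#define INSERT_CHECK(expr) CHECK((expr).second)
#define INSERT_CHECK_OR_RETURN(expr) CHECK_OR_RETURN((expr).second)

MaybeVoid AddWhiteSetEdge(std::set<int>* edges, int edge) {
  INSERT_CHECK_OR_RETURN(edges->insert(edge));  // duplicate -> error, not abort
  return MaybeVoid{true};
}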
111 changes: 58 additions & 53 deletions oneflow/core/job_rewriter/quantization_aware_training.cpp
@@ -40,11 +40,12 @@ const std::string MUL_BIAS_SUFFIX = "-fake-quant-mul-bias";
 const std::string OBSERVER_SUFFIX = "-fake-quant-observer";
 const std::string TRAIN_STEP_SUFFIX = "-fake-train-step";
 
-void VerifyQATList(const OpTypeSet& op_list) {
+Maybe<void> VerifyQATList(const OpTypeSet& op_list) {
   for (const auto& op_type : op_list) {
-    CHECK(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr)
+    CHECK_OR_RETURN(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr)
         << "Cannot find " << op_type << " of QuantAwareTraining list in OpRegistry.";
   }
+  return Maybe<void>::Ok();
 }
 
 HashMap<std::string, std::string> scale_map;
@@ -168,47 +169,47 @@ std::string QuantizationSchemeAttr4QatConfig(const QatConfig& qat_config) {
 }
 
 // TODO: refactor the following 4 methods by registration
-std::string QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) {
+Maybe<std::string> QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "" || target_backend == "tensorrt") {
-    return "google";
+    return std::string("google");
   } else if (target_backend == "cambricon") {
-    return "cambricon";
+    return std::string("cambricon");
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
   }
 }
 
-OpTypeSet Int8List4QatConfig(const QatConfig& qat_config) {
+Maybe<OpTypeSet> Int8List4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "") {
-    return {"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"};
+    return OpTypeSet{"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"};
   } else if (target_backend == "cambricon" || target_backend == "tensorrt") {
-    return {"conv2d", "matmul"};
+    return OpTypeSet{"conv2d", "matmul"};
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
  }
 }
 
-OpTypeSet TransparentList4QatConfig(const QatConfig& qat_config) {
+Maybe<OpTypeSet> TransparentList4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "" || target_backend == "tensorrt") {
-    return {"reshape"};
+    return OpTypeSet{"reshape"};
   } else if (target_backend == "cambricon") {
-    return {};
+    return OpTypeSet{};
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
   }
 }
 
-bool InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) {
+Maybe<bool> InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "" || target_backend == "tensorrt") {
     return true;
   } else if (target_backend == "cambricon") {
     return false;
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
   }
 }
@@ -226,16 +227,18 @@ user_op::UserOpConfWrapper MultiplyOp(const std::string& name, const std::string
   return op_wrapper;
 }
 
-user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::string& input,
-                                          const QatConfig& qat_config,
-                                          const int64_t scope_symbol_id, OpConfMap* inserted_ops) {
+Maybe<user_op::UserOpConfWrapper> MinMaxObserver(const std::string& name, const std::string& input,
+                                                 const QatConfig& qat_config,
+                                                 const int64_t scope_symbol_id,
+                                                 OpConfMap* inserted_ops) {
   const auto op_wrapper =
       user_op::UserOpConfWrapperBuilder(name)
           .Op("min_max_observer")
           .Input("in", input)
           .Output("scale")
           .Output("zero_point")
-          .Attr<std::string>("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config))
+          .Attr<std::string>("quantization_formula",
+                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))
           .Attr<std::string>("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config))
           .Attr("per_layer_quantization", PerLayerQuantizationAttr4Config(qat_config))
           .ScopeSymbolId(scope_symbol_id)
@@ -244,11 +247,9 @@ user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::st
   return op_wrapper;
 }
 
-user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const std::string& input,
-                                                const std::string& train_step_lbn,
-                                                const QatConfig& qat_config,
-                                                const int64_t scope_symbol_id,
-                                                OpConfMap* inserted_ops) {
+Maybe<user_op::UserOpConfWrapper> MovingMinMaxObserver(
+    const std::string& name, const std::string& input, const std::string& train_step_lbn,
+    const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) {
   const std::string moving_max_name = name + MOVING_MAX_SUFFIX;
   const std::string moving_min_name = name + MOVING_MIN_SUFFIX;
   const auto moving_max_var =
@@ -276,7 +277,8 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s
           .Output("zero_point")
           .Attr("training", GlobalJobDesc().IsTrain())
           .Attr("stop_update_after_iters", qat_config.moving_min_max_stop_update_after_iters())
-          .Attr<std::string>("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config))
+          .Attr<std::string>("quantization_formula",
+                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))
           .Attr<std::string>("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config))
           .Attr("momentum", qat_config.moving_min_max_momentum())
           .ScopeSymbolId(scope_symbol_id)
@@ -285,17 +287,20 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s
   return op_wrapper;
 }
 
-user_op::UserOpConfWrapper FakeQuantOp(const std::string& name, const std::string& input,
-                                       const std::string& scale, const std::string& zero_point,
-                                       const QatConfig& qat_config, const int64_t scope_symbol_id,
-                                       OpConfMap* inserted_ops) {
+Maybe<user_op::UserOpConfWrapper> FakeQuantOp(const std::string& name, const std::string& input,
+                                              const std::string& scale,
+                                              const std::string& zero_point,
+                                              const QatConfig& qat_config,
+                                              const int64_t scope_symbol_id,
+                                              OpConfMap* inserted_ops) {
   const auto op_wrapper =
       user_op::UserOpConfWrapperBuilder(name)
           .Op("fake_quantization")
           .Input("in", input)
           .Input("scale", scale)
           .Input("zero_point", zero_point)
-          .Attr<std::string>("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config))
+          .Attr<std::string>("quantization_formula",
+                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))
           .Attr<std::string>("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config))
           .Output("out")
           .ScopeSymbolId(scope_symbol_id)
@@ -329,15 +334,15 @@ Maybe<void> GetScaleAndZeroPointLbn4Edge(OpEdge* edge, const std::string train_s
     const std::string observer_op_name = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX;
     if (IsWeightEdge(edge)) {
       const auto observer_op =
-          MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops);
-      *scale = observer_op.output("scale", 0);
-      *zero_point = observer_op.output("zero_point", 0);
+          JUST(MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops));
+      *scale = observer_op->output("scale", 0);
+      *zero_point = observer_op->output("zero_point", 0);
     } else {
       CHECK_OR_RETURN(qat_config.has_moving_min_max_stop_update_after_iters());
-      const auto observer_op = MovingMinMaxObserver(observer_op_name, lbn, train_step_lbn,
-                                                    qat_config, scope_symbol_id, inserted_ops);
-      *scale = observer_op.output("scale", 0);
-      *zero_point = observer_op.output("zero_point", 0);
+      const auto observer_op = JUST(MovingMinMaxObserver(
+          observer_op_name, lbn, train_step_lbn, qat_config, scope_symbol_id, inserted_ops));
+      *scale = observer_op->output("scale", 0);
+      *zero_point = observer_op->output("zero_point", 0);
     }
   }
   return Maybe<void>::Ok();
@@ -374,30 +379,30 @@ class QuantAwareTraining final : public JobPass {
                                  HashSet<OpNode*> downstream_white, Job* job) const;
 };
 
-bool IsNodeQuantizationEnabled(const OpNode& node) {
+Maybe<bool> IsNodeQuantizationEnabled(const OpNode& node) {
   int64_t scope_symbol_id = node.op().op_conf().scope_symbol_id();
-  CHECK(Global<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));
+  CHECK_OR_RETURN(Global<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));
   const Scope& scope = Global<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);
   return scope.Bool("quantization_aware_training");
 }
 
 Maybe<void> QuantAwareTraining::Apply(Job* job, JobPassCtx* ctx) const {
   if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }
   const OpGraph op_graph(*job);
-  CHECK(GlobalJobDesc().DefaultDataType() == DataType::kFloat);
+  CHECK_OR_RETURN(GlobalJobDesc().DefaultDataType() == DataType::kFloat);
 
   const auto qat_config = ctx->job_desc().job_conf().qat_config();
 
-  OpTypeSet int8_list = Int8List4QatConfig(qat_config);
-  OpTypeSet transparent_list = TransparentList4QatConfig(qat_config);
+  OpTypeSet int8_list = *JUST(Int8List4QatConfig(qat_config));
+  OpTypeSet transparent_list = *JUST(TransparentList4QatConfig(qat_config));
   // if `insert_quant_op_after_int8_ops` is false,
   // always insert quant op before int8 ops.
   // if `insert_quant_op_after_int8_ops` is true,
   // always insert quant op after int8 ops
-  bool insert_quant_op_after_int8_ops = InsertQuantOpAfterInt8Ops4QatConfig(qat_config);
+  bool insert_quant_op_after_int8_ops = JUST(InsertQuantOpAfterInt8Ops4QatConfig(qat_config));
 
-  VerifyQATList(int8_list);
-  VerifyQATList(transparent_list);
+  JUST(VerifyQATList(int8_list));
+  JUST(VerifyQATList(transparent_list));
 
   std::function<std::string(OpNode* const&)> OpName4Node = [](OpNode* const& node) {
     return node->op().op_name();
@@ -456,7 +461,7 @@ Maybe<void> QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config,
     const std::string lbn = GenLogicalBlobName(edge->lbis().front());
     scale_map[lbn] = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX + "/scale_0";
     VLOG(3) << "set " << lbn << " to " << scale_map[lbn];
-    INSERT_CHECK(white_set_edges.insert(edge));
+    INSERT_CHECK_OR_RETURN(white_set_edges.insert(edge));
     return Maybe<void>::Ok();
   };
   auto PropagateScale = [](OpNode* node) -> Maybe<void> {
@@ -478,16 +483,16 @@ Maybe<void> QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config,
   if (IsKeyFound(white_set, node)) {
     for (OpEdge* edge : node->in_edges()) {
       if (IsKeyFound(white_set, edge->src_node())) { continue; }
-      if (IsNodeQuantizationEnabled(*edge->dst_node())) { JUST(AddWhiteSetEdge(edge)); }
+      if (JUST(IsNodeQuantizationEnabled(*edge->dst_node()))) { JUST(AddWhiteSetEdge(edge)); }
     }
     if (IsNodeInList(int8_list, node)) {
       if (insert_quant_op_after_int8_ops) {
         OpNode* inference_node = JUST(GetInferenceOutputNode(op_graph, node));
-        if (IsNodeQuantizationEnabled(*inference_node)) {
+        if (JUST(IsNodeQuantizationEnabled(*inference_node))) {
           for (OpEdge* edge : inference_node->out_edges()) { JUST(AddWhiteSetEdge(edge)); }
         }
       } else {
-        if (IsNodeQuantizationEnabled(*node)) {
+        if (JUST(IsNodeQuantizationEnabled(*node))) {
          for (OpEdge* edge : node->in_edges()) {
             if (white_set_edges.find(edge) == white_set_edges.end()) {
               JUST(AddWhiteSetEdge(edge));
@@ -535,10 +540,10 @@ Maybe<void> QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config,
     JUST(GetScaleAndZeroPointLbn4Edge(edge, job->job_conf().train_conf().train_step_lbn(), &scale,
                                       &zero_point, qat_config, scope_symbol_id, &inserted_ops));
     const std::string fake_quant_op_name = ReplaceSlashToDash4Lbn(lbn) + FAKE_QUANT_SUFFIX;
-    const auto fake_quant_op = FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, qat_config,
-                                           scope_symbol_id, &inserted_ops);
+    const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point,
+                                                qat_config, scope_symbol_id, &inserted_ops));
 
-    const std::string fake_quant_op_output_name = fake_quant_op.output("out", 0);
+    const std::string fake_quant_op_output_name = fake_quant_op->output("out", 0);
 
     JUST(ReplaceInputLbn4DstNodeOfEdge(edge, fake_quant_op_output_name, &op_conf_cache));
   }
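One pattern runs through this whole file: any helper that can reach UNIMPLEMENTED() now returns Maybe<T>, uses UNIMPLEMENTED_THEN_RETURN() so that every branch returns a value, and is unwrapped at call sites with JUST(), which propagates the error to the caller. Below is a minimal sketch with simplified stand-in types; OneFlow's real Maybe<T> and JUST() are more elaborate, and the function names here are hypothetical:

// Illustrative stand-ins only; not OneFlow's actual implementation.
#include <memory>
#include <string>

template<typename T>
struct Maybe {
  std::shared_ptr<T> data;  // null signals an error
  bool IsOk() const { return data != nullptr; }
  static Maybe Ok(T v) { return Maybe{std::make_shared<T>(std::move(v))}; }
  static Maybe Error() { return Maybe{nullptr}; }
};

// Like UNIMPLEMENTED_THEN_RETURN(): the unsupported branch now returns,
// so GCC 4.8.5 no longer sees control reaching the end of the function.
#define UNIMPLEMENTED_THEN_RETURN() return Maybe<std::string>::Error()

Maybe<std::string> FormulaForBackend(const std::string& backend) {
  if (backend.empty() || backend == "tensorrt") {
    return Maybe<std::string>::Ok("google");
  } else if (backend == "cambricon") {
    return Maybe<std::string>::Ok("cambricon");
  }
  UNIMPLEMENTED_THEN_RETURN();
}

Maybe<std::string> BuildAttr(const std::string& backend) {
  Maybe<std::string> formula = FormulaForBackend(backend);
  // This check-and-propagate step is what JUST(...) does.
  if (!formula.IsOk()) { return Maybe<std::string>::Error(); }
  return Maybe<std::string>::Ok("quantization_formula=" + *formula.data);
}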
2 changes: 2 additions & 0 deletions oneflow/xrt/xla/ops/scalar_binary_op.cpp
@@ -44,6 +44,8 @@ class ScalarBinaryOp : public XlaOpKernel {
       double value = ctx->Attr<double>("float_operand");
       return FloatLiteral(builder, data_type, value);
     }
+    UNIMPLEMENTED();
+    return xla::XlaOp();
   }
 };
