Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gen_bw_fn return maybe #5454

Merged
merged 36 commits into from
Jul 17, 2021
Merged
Changes from 7 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d897143
modified SetInputArgModifyFn
luqiang-guo Jul 1, 2021
e33df72
Delete the CHECK changes in the assign_op.cpp file
luqiang-guo Jul 9, 2021
b6732e1
Format
luqiang-guo Jul 9, 2021
6e674d3
Modified the OutputArgModifyFn interface
luqiang-guo Jul 9, 2021
038c42f
Merge branch 'master' into Replace_check_using_maybe_check_part_SetIn…
luqiang-guo Jul 9, 2021
bb22c4d
Merge branch 'master' into Replace_check_using_maybe_check_part_Outpu…
luqiang-guo Jul 9, 2021
c7fccf9
add return
luqiang-guo Jul 9, 2021
9c3fe8b
maybe error stack from CheckAndConstructOp to OutputArgModifier callb…
liufengwei0103 Jul 10, 2021
3434303
maybe error stack from CheckAndConstructOp to OutputArgModifier callb…
liufengwei0103 Jul 10, 2021
46ddb46
Merge branch 'OutputArgModifier_return_maybe' of https://github.com/o…
liufengwei0103 Jul 10, 2021
2d0b5a7
OutputArgModifier return maybe part_1
liufengwei0103 Jul 10, 2021
1d97b6c
Merge branch 'OutputArgModifier_return_maybe_part_1' of https://githu…
liufengwei0103 Jul 10, 2021
dc45b30
maybe error stack from CheckAndConstructOp to OutputArgModifier callb…
liufengwei0103 Jul 10, 2021
1f0e385
Merge branch 'OutputArgModifier_return_maybe' of https://github.com/o…
liufengwei0103 Jul 10, 2021
fab170a
Merge branch 'OutputArgModifier_return_maybe_part_1' of https://githu…
liufengwei0103 Jul 10, 2021
eb78626
input_arg_modifier return maybe
liufengwei0103 Jul 10, 2021
041eecf
gen_bw_fn return maybe
liufengwei0103 Jul 10, 2021
3c213d8
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 16, 2021
648ff8b
add MakeGenBackwardOpConf because ofstatement-expression not allowed …
liufengwei0103 Jul 16, 2021
164a047
Merge branch 'master' into gen_bw_fn_return_maybe
liufengwei0103 Jul 16, 2021
4ef9bf7
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 16, 2021
b93b81c
add maybe after merge master
liufengwei0103 Jul 16, 2021
4473cde
Merge branch 'gen_bw_fn_return_maybe' of https://github.com/Oneflow-I…
liufengwei0103 Jul 16, 2021
88e0979
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 16, 2021
d5f4b10
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
6b1a885
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
3b43178
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
3842722
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
33c2e3e
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
7899fbc
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
4dfdb70
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
edaf987
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 17, 2021
74ab623
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 17, 2021
09a53d3
Merge branch 'gen_bw_fn_return_maybe' of https://github.com/Oneflow-I…
liufengwei0103 Jul 17, 2021
5a90886
fix bug: JUST in lambda
liufengwei0103 Jul 17, 2021
fba73cf
Merge branch 'master' into gen_bw_fn_return_maybe
liufengwei0103 Jul 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions oneflow/core/eager/opkernel_object.cpp
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ Maybe<void> OpKernelObject::ResetOpAndKernel(
const OpNodeSignatureDesc& op_node_signature, const ParallelContext* parallel_ctx,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc* parallel_desc) {
auto op = ConstructOp(op_conf_, device_type_);
auto op = JUST(ConstructOp(op_conf_, device_type_));
JUST(op->FillOpParallelDesc(*parallel_desc));
const auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& {
return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn));
@@ -55,7 +55,7 @@ Maybe<void> SystemOpKernelObject::ResetKernel(
const OpNodeSignatureDesc& op_node_signature, const ParallelContext* parallel_ctx,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc* parallel_desc) {
auto op = ConstructOp(op_conf_, device_type_);
auto op = JUST(ConstructOp(op_conf_, device_type_));
JUST(op->FillOpParallelDesc(*parallel_desc));
const auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& {
return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn));
2 changes: 1 addition & 1 deletion oneflow/core/framework/op_kernel_infer_cache.cpp
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ namespace user_op {

OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const JobDesc& job_desc) {
const OperatorConf& op_conf = kernel_conf.op_attribute().op_conf();
std::shared_ptr<Operator> op = ConstructOp(op_conf);
std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(op_conf));
cache_key_.job_desc = &job_desc;
cache_key_.op_conf_sym = op->GetOpConfWithoutOpNameAndLbn();
cache_key_.ibn_idx2shape_sym.resize(op->input_bns().size());
3 changes: 2 additions & 1 deletion oneflow/core/framework/user_op_registry.h
Original file line number Diff line number Diff line change
@@ -52,7 +52,8 @@ using InputArgModifyFn = std::function<Maybe<void>(GetInputArgModifier, const Us
using OutputArgModifier = OutputBlobModifier;
using GetOutputArgModifier =
std::function<OutputArgModifier*(const std::string& out_arg_name, int32_t out_arg_index)>;
using OutputArgModifyFn = std::function<void(GetOutputArgModifier, const UserOpConfWrapper&)>;
using OutputArgModifyFn =
std::function<Maybe<void>(GetOutputArgModifier, const UserOpConfWrapper&)>;
using OutputBlobTimeShapeInferFn = std::function<Maybe<void>(InferOutputBlobTimeShapeFnContext*)>;
using ParallelDistributionInferFn = std::function<Maybe<void>(InferParallelDistributionFnContext*)>;

2 changes: 1 addition & 1 deletion oneflow/core/graph/boxing_identity_task_node.cpp
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ void BoxingIdentityTaskNode::BuildExecGphAndRegst() {
op_conf.set_name("System-Boxing-Identity-" + NewUniqueId());
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
*op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi();
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf);
std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
2 changes: 1 addition & 1 deletion oneflow/core/graph/boxing_zeros_task_node.cpp
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() {
*op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi();
shape_.ToProto(op_conf.mutable_boxing_zeros_conf()->mutable_shape());
op_conf.mutable_boxing_zeros_conf()->set_data_type(data_type_);
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf);
std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
node->mut_op() = sole_op;
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
2 changes: 1 addition & 1 deletion oneflow/core/graph/collective_boxing_pack_task_node.cpp
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() {
src_sbp_parallel_.ToProto(collective_boxing_pack_conf->mutable_src_sbp_parallel());
dst_sbp_parallel_.ToProto(collective_boxing_pack_conf->mutable_dst_sbp_parallel());
collective_boxing_pack_conf->set_num_ranks(parallel_num_);
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf);
std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
2 changes: 1 addition & 1 deletion oneflow/core/graph/collective_boxing_task_node.cpp
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ void CollectiveBoxingGenericTaskNode::ConsumeAllRegsts() {

void CollectiveBoxingGenericTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> boxing_op = ConstructOp(op_conf_);
std::shared_ptr<Operator> boxing_op = CHECK_JUST(ConstructOp(op_conf_));
node->mut_op() = boxing_op;
for (const std::string& ibn : boxing_op->input_bns()) {
node->BindBnWithRegst(ibn, GetSoleConsumedRegst("in"));
2 changes: 1 addition & 1 deletion oneflow/core/graph/collective_boxing_unpack_task_node.cpp
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() {
src_sbp_parallel_.ToProto(collective_boxing_unpack_conf->mutable_src_sbp_parallel());
dst_sbp_parallel_.ToProto(collective_boxing_unpack_conf->mutable_dst_sbp_parallel());
collective_boxing_unpack_conf->set_num_ranks(parallel_num_);
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf);
std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
2 changes: 1 addition & 1 deletion oneflow/core/graph/copy_task_node.cpp
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ void CopyTaskNode::BuildExecGphAndRegst() {
auto in_regst = GetSoleConsumedRegst("copy_in");
out_regst->CopyBlobDescFrom(in_regst.get());
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = ConstructOp(NewCopyOpConf());
node->mut_op() = CHECK_JUST(ConstructOp(NewCopyOpConf()));
node->BindBnWithRegst(node->op()->SoleIbn(), in_regst);
node->BindBnWithRegst(node->op()->SoleObn(), out_regst);
}
2 changes: 1 addition & 1 deletion oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ const cfg::ParallelDistribution& OpNode::ParallelDistribution4Lbi(const LogicalB
OpNode::OpNode(const std::shared_ptr<const ParallelDesc>& parallel_desc,
const OperatorConf& op_conf)
: parallel_desc_(parallel_desc),
op_(ConstructOp(op_conf, parallel_desc->device_type())),
op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))),
ibns_(op_->input_bns().begin(), op_->input_bns().end()) {
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc));
}
2 changes: 1 addition & 1 deletion oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ void SliceBoxingTaskNode::ConsumeAllRegsts() {

void SliceBoxingTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> op = ConstructOp(GetBoxingOpConf());
std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(GetBoxingOpConf()));
node->mut_op() = op;
FOR_RANGE(size_t, i, 0, op->input_bns().size()) {
const std::string& ibn = op->input_bns().Get(i);
6 changes: 3 additions & 3 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
@@ -500,7 +500,7 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferMirroredOp(const OperatorConf
const auto& scope = Global<symbol::Storage<Scope>>::Get()->Get(op_conf.scope_symbol_id());
const auto* job_desc = JUST(scope.job_desc());
const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf));
auto op = ConstructOp(op_conf, parallel_desc.device_type());
auto op = JUST(ConstructOp(op_conf, parallel_desc.device_type()));
JUST(CheckAllInputsConvertableToMirroredBlob(*op));
int32_t parallel_num = parallel_desc.parallel_num();
JUST(CheckAllInputsWithSameParallelNum(*op, parallel_num));
@@ -571,7 +571,7 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
CHECK_NE_OR_RETURN(op_conf.device_tag(), "invalid_device")
<< Error::OpConfDeviceTagNoSetError() << "op_name: " << op_name << " not set device tag";

op_name2op_.emplace(op_name, ConstructOp(op_conf));
op_name2op_.emplace(op_name, JUST(ConstructOp(op_conf)));
Operator* op = op_name2op_.at(op_name).get();

cfg::SbpSignature sbp_sig_conf;
@@ -1090,7 +1090,7 @@ void JobBuildAndInferCtx::InferBlobBackwardSignature(
// find backward used logical blob ids
auto backward_used_lbis = std::make_shared<HashSet<LogicalBlobId>>();
for (const auto& bw_op_conf : bw_op_confs) {
const auto& bw_op = ConstructOp(bw_op_conf, op.device_type());
const auto& bw_op = CHECK_JUST(ConstructOp(bw_op_conf, op.device_type()));
for (const auto& ibn : bw_op->input_bns()) {
const auto& lbi = bw_op->BnInOp2Lbi(ibn);
if (FwLogicalBlobDescPtr4Lbi(lbi) != nullptr) { backward_used_lbis->insert(lbi); }
5 changes: 3 additions & 2 deletions oneflow/core/job/job_builder.cpp
Original file line number Diff line number Diff line change
@@ -316,14 +316,15 @@ void JobBuilder::AddOrMutOpsOnlyOnce(const ParallelConf& parallel_conf,
MutOpsOnlyOnce(mut_ops);
}

void JobBuilder::ForEachOperator(const std::function<void(const Operator&)>& Handler) const {
Maybe<void> JobBuilder::ForEachOperator(const std::function<void(const Operator&)>& Handler) const {
for (const auto& pair : op_name2op_conf_) {
auto it = op_name2parallel_conf_.find(pair.first);
CHECK(it != op_name2parallel_conf_.end()) << "op_name: " << pair.first;
DeviceType device_type = ParallelDesc(*it->second).device_type();
std::shared_ptr<Operator> op = ConstructOp(*pair.second, device_type);
std::shared_ptr<Operator> op = JUST(ConstructOp(*pair.second, device_type));
Handler(*op);
}
return Maybe<void>::Ok();
}

const ParallelConf& JobBuilder::ParallelConf4OpName(const std::string& op_name) const {
2 changes: 1 addition & 1 deletion oneflow/core/job/job_builder.h
Original file line number Diff line number Diff line change
@@ -70,7 +70,7 @@ class JobBuilder final {
void SetSbpParallel4Oba(const OpBlobArg& oba, const cfg::SbpParallel& sbp_parallel);
void SetParallelDistribution4Oba(const OpBlobArg& oba,
const cfg::ParallelDistribution& parallel_distribution);
void ForEachOperator(const std::function<void(const Operator&)>& Handler) const;
Maybe<void> ForEachOperator(const std::function<void(const Operator&)>& Handler) const;

const ParallelConf& ParallelConf4Lbi(const LogicalBlobId& lbi) const;
const ParallelConf& ParallelConf4OpName(const std::string& op_name) const;
19 changes: 10 additions & 9 deletions oneflow/core/job_rewriter/autograd.cpp
Original file line number Diff line number Diff line change
@@ -401,11 +401,11 @@ void BindFwBwObaPairs(const OpGraph& op_graph, const OpBlobArgPairs& fw_bw_oba_p
}
}

void CalcFwBwObaPairs(const OpGraph& op_graph,
const HashMap<OpBlobArg, LogicalBlobId>& in_oba2in_diff_lbi,
const HashMap<OpBlobArg, LogicalBlobId>& out_oba2out_diff_lbi,
const HashMap<OpBlobArg, LogicalBlobId>& out_oba2clone_bw_add_out_lbi,
const JobBuilder& job_builder, OpBlobArgPairs* fw_bw_oba_pairs) {
Maybe<void> CalcFwBwObaPairs(const OpGraph& op_graph,
const HashMap<OpBlobArg, LogicalBlobId>& in_oba2in_diff_lbi,
const HashMap<OpBlobArg, LogicalBlobId>& out_oba2out_diff_lbi,
const HashMap<OpBlobArg, LogicalBlobId>& out_oba2clone_bw_add_out_lbi,
const JobBuilder& job_builder, OpBlobArgPairs* fw_bw_oba_pairs) {
HashMap<LogicalBlobId, OpBlobArg> in_diff_lbi2in_oba;
op_graph.ReverseTopoForEachNode([&](OpNode* op_node) {
const auto& op = op_node->op();
@@ -432,7 +432,7 @@ void CalcFwBwObaPairs(const OpGraph& op_graph,
for (const auto& pair : out_oba2clone_bw_add_out_lbi) {
CHECK(clone_bw_add_out_lbi2out_oba.emplace(pair.second, pair.first).second);
}
job_builder.ForEachOperator([&](const Operator& op) {
JUST(job_builder.ForEachOperator([&](const Operator& op) {
for (const auto& ibn : op.input_bns()) {
const auto& out_oba_it = out_diff_lbi2out_oba.find(op.BnInOp2Lbi(ibn));
if (out_oba_it == out_diff_lbi2out_oba.end()) { continue; }
@@ -457,7 +457,8 @@ void CalcFwBwObaPairs(const OpGraph& op_graph,
*pair->mutable_second() = clone_out_oba_it->second;
}
}
});
}));
return Maybe<void>::Ok();
}

void InitOutOba2OutDiffLbi(JobPassCtx* ctx, const OpGraph& op_graph,
@@ -845,8 +846,8 @@ Maybe<void> AutoGrad(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_b
}
}
OpBlobArgPairs fw_bw_oba_pairs;
CalcFwBwObaPairs(op_graph, in_oba2in_diff_lbi, out_oba2out_diff_lbi, out_oba2clone_bw_add_out_lbi,
*job_builder, &fw_bw_oba_pairs);
JUST(CalcFwBwObaPairs(op_graph, in_oba2in_diff_lbi, out_oba2out_diff_lbi,
out_oba2clone_bw_add_out_lbi, *job_builder, &fw_bw_oba_pairs));
BindFwBwObaPairs(op_graph, fw_bw_oba_pairs, identical_sbp_oba_pairs);
CalcOutLbi2OutDiffLbi(op_graph, out_oba2out_diff_lbi, out_lbi2out_diff_lbi);
return Maybe<void>::Ok();
4 changes: 2 additions & 2 deletions oneflow/core/job_rewriter/autotick.cpp
Original file line number Diff line number Diff line change
@@ -164,15 +164,15 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick
CHECK(src_subset_tick_op.has_src_subset_tick_conf());
const std::string& src_lbn =
src_subset_tick_op.name() + "/" + src_subset_tick_op.src_subset_tick_conf().out();
job_builder->ForEachOperator([&](const Operator& op) {
JUST(job_builder->ForEachOperator([&](const Operator& op) {
if (op.op_name() != src_subset_tick_op.name()) {
CHECK(!op.op_conf().has_src_subset_tick_conf());
}
auto mut_helper = NewMutOpConTickInputHelper(op.op_conf());
if (!mut_helper) { return; }
if (mut_helper->IsTickInputBound() == true) { return; }
job_builder->MutOpsOnlyOnce({mut_helper->NewTickInputBoundOpConf(src_lbn)});
});
}));
return Maybe<void>::Ok();
}

2 changes: 1 addition & 1 deletion oneflow/core/kernel/runtime_blob_shape_infer_helper.cpp
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ namespace oneflow {
RuntimeBlobShapeInferHelper::RuntimeBlobShapeInferHelper(const OperatorConf& op_conf,
const KernelConf& kernel_conf,
const JobDesc* job_desc) {
op_ = ConstructOp(op_conf);
op_ = CHECK_JUST(ConstructOp(op_conf));
const OpAttribute& op_attribute = kernel_conf.op_attribute();
if (op_attribute.has_parallel_conf_signature()
&& op_attribute.parallel_conf_signature().has_op_parallel_conf()) {
3 changes: 2 additions & 1 deletion oneflow/core/operator/acc_tick_op.cpp
Original file line number Diff line number Diff line change
@@ -27,11 +27,12 @@ Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& G

} // namespace

void AccTickOp::InitFromOpConf() {
Maybe<void> AccTickOp::InitFromOpConf() {
CHECK(op_conf().has_acc_tick_conf());

EnrollInputBn("one", false);
EnrollOutputBn("acc", false);
return Maybe<void>::Ok();
}

Maybe<void> AccTickOp::InferLogicalOutBlobDescs(
2 changes: 1 addition & 1 deletion oneflow/core/operator/acc_tick_op.h
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ class AccTickOp final : public Operator {
AccTickOp() = default;
~AccTickOp() = default;

void InitFromOpConf() override;
Maybe<void> InitFromOpConf() override;

Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
5 changes: 3 additions & 2 deletions oneflow/core/operator/assign_op.cpp
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ class AssignOp final : public Operator {
AssignOp() = default;
~AssignOp() override = default;

void InitFromOpConf() override;
Maybe<void> InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
@@ -37,10 +37,11 @@ class AssignOp final : public Operator {
cfg::SbpSignatureList* sbp_sig_list) const override;
};

void AssignOp::InitFromOpConf() {
Maybe<void> AssignOp::InitFromOpConf() {
CHECK(op_conf().has_assign_conf());
EnrollInputBn("ref")->set_is_mutable(true);
EnrollInputBn("value");
return Maybe<void>::Ok();
}

std::string DebugString(const BlobDesc& blob_desc) {
5 changes: 3 additions & 2 deletions oneflow/core/operator/boxing_identity_op.cpp
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ class BoxingIdentityOp : public Operator {
BoxingIdentityOp() = default;
~BoxingIdentityOp() override = default;

void InitFromOpConf() override;
Maybe<void> InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
@@ -40,9 +40,10 @@ class BoxingIdentityOp : public Operator {
LogicalBlobId lbi4obn(const std::string& output_bn) const override;
};

void BoxingIdentityOp::InitFromOpConf() {
Maybe<void> BoxingIdentityOp::InitFromOpConf() {
EnrollInputBn("in", false);
EnrollOutputBn("out", false);
return Maybe<void>::Ok();
}

LogicalBlobId BoxingIdentityOp::lbi4ibn(const std::string& input_bn) const {
3 changes: 2 additions & 1 deletion oneflow/core/operator/boxing_op.cpp
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ void BoxingOp::VirtualGenKernelConf(
EraseEmptyBnInVec(GetBlobDesc4BnInOp, op_attribute->mutable_output_bns());
}

void BoxingOp::InitFromOpConf() {
Maybe<void> BoxingOp::InitFromOpConf() {
CHECK(op_conf().has_boxing_conf());
const BoxingOpConf& boxing_conf = op_conf().boxing_conf();

@@ -56,6 +56,7 @@ void BoxingOp::InitFromOpConf() {
for (int32_t i = 0; i < boxing_conf.out_num(); ++i) {
EnrollOutputBn("out_" + std::to_string(i), false);
}
return Maybe<void>::Ok();
}

LogicalBlobId BoxingOp::lbi4ibn(const std::string& input_bn) const {
2 changes: 1 addition & 1 deletion oneflow/core/operator/boxing_op.h
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ class BoxingOp final : public Operator {
BoxingOp() = default;
~BoxingOp() = default;

void InitFromOpConf() override;
Maybe<void> InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
7 changes: 5 additions & 2 deletions oneflow/core/operator/boxing_zeros_op.cpp
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ class BoxingZerosOp : public Operator {
BoxingZerosOp() = default;
~BoxingZerosOp() override = default;

void InitFromOpConf() override;
Maybe<void> InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
@@ -39,7 +39,10 @@ class BoxingZerosOp : public Operator {
LogicalBlobId lbi4obn(const std::string& output_bn) const override;
};

void BoxingZerosOp::InitFromOpConf() { EnrollOutputBn("out", false); }
Maybe<void> BoxingZerosOp::InitFromOpConf() {
EnrollOutputBn("out", false);
return Maybe<void>::Ok();
}

LogicalBlobId BoxingZerosOp::lbi4ibn(const std::string& input_bn) const {
return this->op_conf().boxing_zeros_conf().lbi();
3 changes: 2 additions & 1 deletion oneflow/core/operator/broadcast_to_compatible_with_op.cpp
Original file line number Diff line number Diff line change
@@ -58,11 +58,12 @@ class BroadcastToCompatibleWithOp final : public Operator {
BroadcastToCompatibleWithOp() = default;
~BroadcastToCompatibleWithOp() override = default;

void InitFromOpConf() {
Maybe<void> InitFromOpConf() {
CHECK(op_conf().has_broadcast_to_compatible_with_conf());
EnrollInputBn("x");
EnrollRepeatedInputBn("compatible", false);
EnrollOutputBn("y");
return Maybe<void>::Ok();
}

Maybe<void> InferLogicalOutBlobDescs(
3 changes: 2 additions & 1 deletion oneflow/core/operator/callback_notify_op.cpp
Original file line number Diff line number Diff line change
@@ -18,9 +18,10 @@ limitations under the License.

namespace oneflow {

void CallbackNotifyOp::InitFromOpConf() {
Maybe<void> CallbackNotifyOp::InitFromOpConf() {
CHECK(op_conf().has_callback_notify_conf());
EnrollInputBn("in", false);
return Maybe<void>::Ok();
}

namespace {
2 changes: 1 addition & 1 deletion oneflow/core/operator/callback_notify_op.h
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ class CallbackNotifyOp final : public Operator {
CallbackNotifyOp() = default;
~CallbackNotifyOp() = default;

void InitFromOpConf() override;
Maybe<void> InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
3 changes: 2 additions & 1 deletion oneflow/core/operator/case_op.cpp
Original file line number Diff line number Diff line change
@@ -18,9 +18,10 @@ limitations under the License.

namespace oneflow {

void CaseOp::InitFromOpConf() {
Maybe<void> CaseOp::InitFromOpConf() {
EnrollInputBn("in", false);
EnrollRepeatedOutputBn("out", false);
return Maybe<void>::Ok();
}

namespace {
Loading