Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task infer blob desc support choosing method #10124

Merged
merged 18 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion oneflow/core/graph/boxing_identity_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void BoxingIdentityTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/boxing_zeros_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void BoxingZerosTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/collective_boxing_pack_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/collective_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void CollectiveBoxingGenericTaskNode::BuildExecGphAndRegst() {
node->BindBnWithRegst(obn, out_regst);
out_regst->AddLbi(boxing_op->BnInOp2Lbi(obn));
}
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void CollectiveBoxingGenericTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/collective_boxing_unpack_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
(node->*GetInferBlobDescsMethod())(nullptr);
}

void CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/graph/compute_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class CompTaskNode : public TaskNode {
// op
std::shared_ptr<const Operator> op() const { return op_node_->shared_op(); }

ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {
return &ExecNode::InferBlobDescsByInputs;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该支持 from sbp 和 logical shape ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该支持 from sbp 和 logical shape ?

from sbp 的推理方法和编译模式关联了,所以就没加到这个分支

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其实为了加速,master 编译这里也可以用 from sbp 吧,这样是不是会更快?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其实为了加速,master 编译这里也可以用 from sbp 吧,这样是不是会更快?

当前估计使用 from sbp 不会变快:

  • from sbp 本身的实现开销和之前做 infer physical blobdesc 估计差不多
  • 只适用于 user op,改成通用的影响的地方比较多,有个后续 pr 在做这个
  • master infer 的过程如果不使用并行,加速不明显

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有的op不支持from sbp吧,比如涉及到求平均,求和或者求最大值的。(我记得当前有一个op是这样的,好像是叫bn?)

还有个问题就是如果sbp变动了会怎么样?当前是要重新推导一遍
比如自动并行会大规模修改sbp。这时候如果有个logical desc储存着应该好一点

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有的op不支持from sbp吧,比如涉及到求平均,求和或者求最大值的。(我记得当前有一个op是这样的,好像是叫bn?)

user op 都是符合的,这里使用的场景,之前用 physical infer 推理后,也会再用 sbp 做下 check,所以也是符合的。

还有个问题就是如果sbp变动了会怎么样?当前是要重新推导一遍 比如自动并行会大规模修改sbp。这时候如果有个logical desc储存着应该好一点

因为这个推理发生在 plan 生成阶段,是在 自动不行之后,按说 sbp 已经稳定了。

}

protected:
const OpNode* GetOneSuccOpNodeOnEdge(TaskEdge* edge);
const OpNode* GetOnePredOpNodeOnEdge(TaskEdge* edge);
Expand Down
69 changes: 67 additions & 2 deletions oneflow/core/graph/exec_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,31 @@ Maybe<void> CheckPhysicalBlobDesc(
return Maybe<void>::Ok();
}

// A helper function to infer blob's physical shape with ND SBP.
Maybe<void> InferPhysicalBlobDesc(
const Operator& op, const PbRpf<std::string>& bns,
const std::function<Maybe<const BlobDesc>(const std::string&)>& GetLogicalBlobDesc,
const NdSbpSignature* nd_sbp_signature, const ParallelContext* parallel_ctx,
const std::function<BlobDesc*(const std::string&)>& GetPhysicalBlobDesc) {
const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(op.GetOpParallelDesc());
for (const auto& bn : bns) {
BlobDesc* physical_blob_desc = GetPhysicalBlobDesc(bn);
const auto& logical_blob_desc = *JUST(GetLogicalBlobDesc(bn));
CHECK_NOTNULL_OR_RETURN(physical_blob_desc)
<< "physical_blob_desc should not be nullptr. op location: " << op.op_loc();
*physical_blob_desc = logical_blob_desc;
const auto& physical_shape = JUST_MSG(
GetPhysicalShape(logical_blob_desc.shape(), nd_sbp_signature->bn_in_op2nd_sbp().at(bn),
*op_parallel_desc, *parallel_ctx),
std::stringstream() << " check physical shape failed, op name " << op.op_loc());
physical_blob_desc->set_shape(*physical_shape);
}
return Maybe<void>::Ok();
}

} // namespace

void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) {
void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) {
auto GetBlobDesc4BnInOp = GetBlobDesc4BnInOpFunc();
const OpNode* op_node = Singleton<OpGraph>::Get()->OpNode4OpName(op()->op_name());
const NdSbpSignature* nd_sbp_signature = nullptr;
Expand All @@ -128,7 +150,50 @@ void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) {
CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_,
GetBlobDesc4BnInOp, parallel_ctx),
std::stringstream()
<< " infer inplace obn to ibn if failed, op name " << op_->op_loc());
<< " infer inplace obn to ibn is failed, op name " << op_->op_loc());
}

void ExecNode::InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx) {
const HashSet<std::string> ibns{op()->input_bns().begin(), op()->input_bns().end()};
HashMap<std::string, BlobDesc> ibn2blob_desc{};
const auto& GetBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> BlobDesc* {
// Generate temp regst to store input blob desc, and will be released after infer output blob
// desc.
if (ibns.count(bn_in_op) > 0) {
auto iter = ibn2blob_desc.find(bn_in_op);
if (iter == ibn2blob_desc.end()) {
iter = ibn2blob_desc.emplace(bn_in_op, kInvalidDataType).first;
}
return &iter->second;
}
auto it = bn_in_op2regst_.find(bn_in_op);
if (it == bn_in_op2regst_.end()) { return nullptr; }
std::shared_ptr<RegstDesc> regst = it->second;
CHECK(regst);
return regst->MutBlobDesc(op()->BnInOp2Lbi(bn_in_op));
};
const OpNode* op_node = Singleton<OpGraph>::Get()->OpNode4OpName(op()->op_name());
const NdSbpSignature* nd_sbp_signature = &CHECK_NOTNULL(op_node)->nd_sbp_signature();

// TODO(strint): user op can infer output with SBP, so there is no need to infer the input.
// Reference: https://github.com/Oneflow-Inc/oneflow/pull/8971
// Infer input blob desc with SBP, the infer results are set into the temp input blob desc.
CHECK_JUST(InferPhysicalBlobDesc(
*op(), op()->input_bns(),
std::bind(&Operator::GetLogicalBlobDesc4Ibn, op().get(), std::placeholders::_1),
nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp));

// Infer output blob desc with input.
CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()),
std::stringstream() << " infer blob descs is failed, op name " << op_->op_loc());
CHECK_JUST(CheckPhysicalBlobDesc(
*op(), op()->output_bns(),
std::bind(&Operator::GetLogicalBlobDesc4Obn, op().get(), std::placeholders::_1),
nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp));
CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_,
GetBlobDesc4BnInOp, parallel_ctx),
std::stringstream()
<< " infer inplace obn to ibn is failed, op name " << op_->op_loc());
}

std::function<BlobDesc*(const std::string&)> ExecNode::GetBlobDesc4BnInOpFunc() const {
Expand Down
4 changes: 3 additions & 1 deletion oneflow/core/graph/exec_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std::string VisualStr() const override { return op_->op_name(); }
void ToProto(const ParallelContext*, ExecNodeProto*) const;

void InferBlobDescs(const ParallelContext* parallel_ctx);
typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*);
void InferBlobDescsByInputs(const ParallelContext* parallel_ctx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只提供了一种 method 吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只提供了一种 method 吗

对,相当于这个分支不改变执行逻辑,只提供了接口

void InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx);

const HashMap<std::string, std::string>& mut_inplace_obn2ibn() const {
return mut_inplace_obn2ibn_;
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() {
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
}
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void SliceBoxingTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi());
node->BindBnWithRegst(op->SoleObn(), out_regst);
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

void SliceBoxingTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
DeviceType device_type() const;
virtual const ParallelContext* parallel_ctx() const { return nullptr; }

// Different types of TaskNode/Compile Mode choose different output BlobDesc inference methods
virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const = 0;

// Setters
void set_machine_id(int64_t val);
void set_thrd_id(int64_t val);
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/graph/transport_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ class TransportTaskNode : public TaskNode {
const TaskGraphRebuildCtx& ctx);
void ToTransportTaskProtoIf(TransportTaskProto*) const;

ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {
// TransportTaskNode infers output BlobDesc based on input BlobDesc, because it can't infers
// output BlobDesc with SBP.
return &ExecNode::InferBlobDescsByInputs;
}

private:
virtual Maybe<void> InitTransportTaskFromProto(const TransportTaskProto&,
const TaskGraphRebuildCtx& ctx) = 0;

virtual void ToTransportTaskProto(TransportTaskProto*) const = 0;
LogicalBlobId lbi_;
};
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/acc_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void AccCompTaskNode::BuildExecGphAndRegst() {
exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst);
out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));
exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);
exec_node->InferBlobDescs(parallel_ctx());
(exec_node->*GetInferBlobDescsMethod())(parallel_ctx());
out_regst->ForEachLbi([out_regst](const LogicalBlobId& lbi) {
const BlobDesc* blob_desc = out_regst->GetBlobDesc(lbi);
CHECK_EQ(blob_desc->is_dynamic(), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() {
exec_node->BindBnWithRegst(op->SoleIbn(), in_regst);
out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn()));
exec_node->BindBnWithRegst(op->SoleObn(), out_regst);
exec_node->InferBlobDescs(parallel_ctx());
(exec_node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccCtrlTick);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/acc_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void AccTickCompTaskNode::BuildExecGphAndRegst() {
exec_node->BindBnWithRegst(op->SoleIbn(), in_regst);
out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn()));
exec_node->BindBnWithRegst(op->SoleObn(), out_regst);
exec_node->InferBlobDescs(parallel_ctx());
(exec_node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccTick);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/case_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void CaseCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(sole_op->BnInOp2Lbi(name));
node->BindBnWithRegst(name, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

void CaseCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kCriticalSectionWaitTick);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void DecodeH2DCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_NAMED_TASK_STREAM_INDEX_GETTER(DeviceType::kCUDA, TaskType::kDecodeH2D, "DECODE_H2D")
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/device_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void DeviceTickCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDeviceTick);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ void DistributeConcatCompTaskNode::ConsumeAllRegsts() {
void DistributeConcatCompTaskNode::BuildExecGphAndRegst() {
BuildExecGphStructAndBindInRegst();
BuildOutRegst();
mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
mut_exec_gph().TopoForEachNode(
[this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });
}

void DistributeConcatCompTaskNode::BuildExecGphStructAndBindInRegst() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ void DistributeSplitCompTaskNode::ConsumeAllRegsts() {
void DistributeSplitCompTaskNode::BuildExecGphAndRegst() {
BuildExecGphStructAndBindInRegst();
BuildOutRegst();
mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
mut_exec_gph().TopoForEachNode(
[this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });
}

void DistributeSplitCompTaskNode::BuildExecGphStructAndBindInRegst() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void DstSubsetTickCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kDstSubsetTick);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/esac_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void EsacCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi("out"));
node->BindBnWithRegst("out", out_regst);
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

void EsacCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); }
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/graph_impl/normal_forward_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ void NormalForwardCompTaskNode::BuildExecGphAndRegst() {
BuildExecGphStructAndBindInRegst();
BuildOutRegst();
BuildTmp7BufRegsts();
mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
mut_exec_gph().TopoForEachNode(
[this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });
}

void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/pack_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void PackCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));
exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);

exec_node->InferBlobDescs(parallel_ctx());
(exec_node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kPack);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void ReentrantLockCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

void ReentrantLockCompTaskNode::InferProducedDataRegstTimeShape() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/repeat_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void RepeatCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());

// NOTE(chengcheng): force inplace
CHECK_EQ(in_regst->NumOfLbi(), 1);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/source_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void SourceTickCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSourceTick);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void SrcSubsetTickCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSrcSubsetTick);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class SspVariableProxyCompTaskNode final : public CompTaskNode {
BuildExecGphStructAndBindInRegst();
BuildOutRegst();
mut_exec_gph().TopoForEachNode(
[this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
[this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });
}

void BuildExecGphStructAndBindInRegst() {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void TickCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kTick);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph_impl/unpack_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void UnpackCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));
exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);
exec_node->InferBlobDescs(parallel_ctx());
(exec_node->*GetInferBlobDescsMethod())(parallel_ctx());
}

REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kUnpack);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void WaitAndSendIdsCompTaskNode::BuildExecGphAndRegst() {
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
(node->*GetInferBlobDescsMethod())(parallel_ctx());
}

void WaitAndSendIdsCompTaskNode::InferProducedDataRegstTimeShape() {
Expand Down