Skip to content

Commit

Permalink
interface_op support parallel_distribution (#4479)
Browse files Browse the repository at this point in the history
* interface_op support parallel_distribution

* add JUST

* fix

* fix

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
guo-ran and oneflow-ci-bot authored Mar 24, 2021
1 parent b2c5212 commit 9d6ab3f
Show file tree
Hide file tree
Showing 15 changed files with 138 additions and 51 deletions.
7 changes: 5 additions & 2 deletions oneflow/core/framework/op_arg_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,13 @@ void OpArgParallelAttribute::Assign(const std::shared_ptr<OpArgParallelAttribute

void OpArgParallelAttribute::DumpToInterfaceBlobConf(
std::shared_ptr<cfg::InterfaceBlobConf> interface_blob_conf) const {
interface_blob_conf->mutable_parallel_distribution()->clear_sbp_parallel();
if (sbp_parallel_->has_split_parallel()) {
interface_blob_conf->mutable_split_axis()->set_value(sbp_parallel_->split_parallel().axis());
*interface_blob_conf->mutable_parallel_distribution()->add_sbp_parallel() = *sbp_parallel_;
} else {
interface_blob_conf->clear_split_axis();
interface_blob_conf->mutable_parallel_distribution()
->add_sbp_parallel()
->mutable_broadcast_parallel();
}
}

Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/job/job_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ message IndexedSlicesOptimizerConf {
message ParallelBlobConf {
required BlobDescProto logical_blob_desc_conf = 1;
required ParallelConf parallel_conf = 2;
required SbpParallel sbp_conf = 3;
required ParallelDistribution parallel_distribution = 3;
}

message JobInputDef {
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/job/model_io_job.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ OperatorConf GenForeignInputOpConf(const std::string& job_name, const int64_t in
*blob_conf->mutable_shape()->mutable_dim()->Add() = input_size;
blob_conf->set_data_type(DataType::kInt8);
blob_conf->set_is_dynamic(true);
blob_conf->mutable_split_axis()->clear_value();
return foreign_input_op_conf;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/job/model_io_v2_job.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ OperatorConf GenForeignInputOpConf(const std::string& job_name, const int64_t in
*blob_conf->mutable_shape()->mutable_dim()->Add() = input_size;
blob_conf->set_is_dynamic(true);
blob_conf->set_data_type(DataType::kInt8);
blob_conf->mutable_split_axis()->clear_value();
return foreign_input_op_conf;
}

Expand Down
13 changes: 7 additions & 6 deletions oneflow/core/job/oneflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ namespace oneflow {

bool operator==(const ParallelBlobConf& lhs, const ParallelBlobConf& rhs) {
return BlobDesc(lhs.logical_blob_desc_conf()) == BlobDesc(rhs.logical_blob_desc_conf())
&& lhs.parallel_conf() == rhs.parallel_conf() && lhs.sbp_conf() == rhs.sbp_conf();
&& lhs.parallel_conf() == rhs.parallel_conf()
&& lhs.parallel_distribution() == rhs.parallel_distribution();
}

namespace {
Expand Down Expand Up @@ -504,11 +505,11 @@ void GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& o
ParallelBlobConf ret;
*blob_conf->mutable_parallel_conf() = job_builder.ParallelConf4OpName(op_name);
*blob_conf->mutable_logical_blob_desc_conf() = job.helper().lbn2logical_blob_desc().at(lbn);
*blob_conf->mutable_sbp_conf() = job.job_parallel_view_conf()
.op_name2sbp_signature_conf()
.at(op_name)
.bn_in_op2sbp_parallel()
.at(obn);
*blob_conf->mutable_parallel_distribution() = job.job_parallel_view_conf()
.op_name2parallel_distribution_signature_conf()
.at(op_name)
.bn_in_op2parallel_distribution()
.at(obn);
}

void FilterOpName2ParallelBlobConf(
Expand Down
33 changes: 27 additions & 6 deletions oneflow/core/operator/input_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Maybe<void> InputOp::InferOutBlobDescs(
const ParallelContext* parallel_ctx) const {
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
JUST(InterfaceOpUtil::InferOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc,
parallel_ctx));
parallel_ctx, *JUST(GetOpParallelDesc())));
return Maybe<void>::Ok();
}

Expand All @@ -51,15 +51,36 @@ Maybe<void> InputOp::InferSbpSignature(
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const {
InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(),
output_bns(), sbp_signature);
JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(),
output_bns(), sbp_signature));
return Maybe<void>::Ok();
}

Maybe<void> InputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(),
output_bns(),
sbp_sig_list->mutable_sbp_signature()->Add());
JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(),
output_bns(),
sbp_sig_list->mutable_sbp_signature()->Add()));
return Maybe<void>::Ok();
}

// Builds the parallel-distribution signature of the input op.
//
// "tick" receives one broadcast SbpParallel entry per axis of the parallel
// hierarchy of `parallel_desc`. "out" is filled by
// InterfaceOpUtil::ParseParallelDistributionFromBlobConf from the blob_conf of
// this op's input_conf; any error it reports is propagated to the caller.
// `parallel_distribution_constraints` and `ParallelDistributionInferHint4Ibn`
// are not consulted by this implementation.
Maybe<void> InputOp::InferParallelDistributionSignature(
    ParallelDistributionSignature* parallel_distribution_signature,
    const ParallelDistributionSignature& parallel_distribution_constraints,
    const ParallelDesc& parallel_desc,
    std::function<Maybe<const ParallelDistributionInferHint*>(const std::string&)>
        ParallelDistributionInferHint4Ibn) const {
  auto* bn2parallel_distribution =
      parallel_distribution_signature->mutable_bn_in_op2parallel_distribution();

  // The tick argument is broadcast along every axis of the hierarchy.
  ParallelDistribution* tick_distribution = &(*bn2parallel_distribution)["tick"];
  tick_distribution->clear_sbp_parallel();
  const int64_t num_hierarchy_axes = parallel_desc.hierarchy()->NumAxes();
  for (int64_t axis = 0; axis < num_hierarchy_axes; ++axis) {
    tick_distribution->add_sbp_parallel()->mutable_broadcast_parallel();
  }

  // The output distribution comes from the interface blob conf.
  ParallelDistribution* out_distribution = &(*bn2parallel_distribution)["out"];
  JUST(InterfaceOpUtil::ParseParallelDistributionFromBlobConf(
      op_conf().input_conf().blob_conf(), parallel_desc, out_distribution));
  return Maybe<void>::Ok();
}

Expand Down
6 changes: 6 additions & 0 deletions oneflow/core/operator/input_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class InputOp final : public Operator {

Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;
Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;
// Fills `parallel_distribution_signature` for this op; overrides the virtual
// declared by the Operator base class. The definition lives in input_op.cpp,
// where the "tick" and "out" entries of the signature are written.
Maybe<void> InferParallelDistributionSignature(
ParallelDistributionSignature* parallel_distribution_signature,
const ParallelDistributionSignature& parallel_distribution_constraints,
const ParallelDesc& parallel_desc,
std::function<Maybe<const ParallelDistributionInferHint*>(const std::string&)>
ParallelDistributionInferHint4Ibn) const override;
};

} // namespace oneflow
Expand Down
5 changes: 3 additions & 2 deletions oneflow/core/operator/interface_blob_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package oneflow;

import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/job/sbp_parallel.proto";

message InterfaceBlobConf {
optional ShapeProto shape = 1;
optional DataType data_type = 2;
optional OptInt64 split_axis = 3;
optional bool is_dynamic = 5;
optional bool is_dynamic = 3;
optional ParallelDistribution parallel_distribution = 4;
}
50 changes: 31 additions & 19 deletions oneflow/core/operator/interface_op_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ void CheckShape(const Shape& shape) {
Maybe<void> GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf<std::string>& input_bns,
const PbRpf<std::string>& output_bns, SbpSignature* sbp_signature,
bool is_for_input_op) {
if (blob_conf.split_axis().has_value()) {
if (!blob_conf.has_parallel_distribution()) {
SbpSignatureBuilder().Broadcast(input_bns).Broadcast(output_bns).Build(sbp_signature);
return Maybe<void>::Ok();
}
CHECK_EQ_OR_RETURN(blob_conf.parallel_distribution().sbp_parallel_size(), 1);
const auto& sbp_parallel = blob_conf.parallel_distribution().sbp_parallel(0);
if (sbp_parallel.has_split_parallel()) {
int64_t num_axes = blob_conf.shape().dim_size();
int64_t split_axis = blob_conf.split_axis().value();
if (split_axis < 0) { split_axis += num_axes; }
int64_t split_axis = sbp_parallel.split_parallel().axis();
CHECK_GE_OR_RETURN(split_axis, 0);
CHECK_LT_OR_RETURN(split_axis, num_axes);

SbpSignatureBuilder sbp_signature_builder;
if (is_for_input_op) {
// broadcast tick args for InputOp
Expand All @@ -52,27 +56,27 @@ Maybe<void> GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf<std:

Maybe<void> InterfaceOpUtil::InferOutBlobDesc(const InterfaceBlobConf& blob_conf,
BlobDesc* out_blob_desc,
const ParallelContext* parallel_ctx) {
out_blob_desc->mut_shape() = Shape(blob_conf.shape());
CheckShape(out_blob_desc->shape());
CHECK_GT(out_blob_desc->mut_shape().At(0), 0);
const ParallelContext* parallel_ctx,
const ParallelDesc& parallel_desc) {
ParallelDistribution parallel_distribution;
JUST(ParseParallelDistributionFromBlobConf(blob_conf, parallel_desc, &parallel_distribution));
out_blob_desc->mut_shape() = *JUST(GetPhysicalShape(
Shape(blob_conf.shape()), parallel_distribution, parallel_desc, *parallel_ctx));
out_blob_desc->set_data_type(blob_conf.data_type());
out_blob_desc->set_is_dynamic(blob_conf.is_dynamic());
if (blob_conf.split_axis().has_value()) {
int64_t split_axis = blob_conf.split_axis().value();
BalancedSplitter bs(out_blob_desc->shape().At(split_axis), parallel_ctx->parallel_num());
out_blob_desc->mut_shape().Set(split_axis, bs.At(parallel_ctx->parallel_id()).size());
}
return Maybe<void>::Ok();
}

// Fills the logical out blob description from an interface blob conf.
//
// Requires `blob_conf` to have shape, data_type and is_dynamic set; a missing
// field is reported as an error through the returned Maybe. The shape is
// copied as-is, then validated with CheckShape and a positive-leading-dimension
// check. `parallel_desc` is not used by this implementation.
Maybe<void> InterfaceOpUtil::InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf,
                                                     BlobDesc* out_blob_desc,
                                                     const ParallelDesc& parallel_desc) {
  CHECK_OR_RETURN(blob_conf.has_shape());
  out_blob_desc->mut_shape() = Shape(blob_conf.shape());
  CheckShape(out_blob_desc->shape());
  // Return the failure instead of aborting, matching the other *_OR_RETURN
  // checks in this function so callers wrapping the call in JUST see the error.
  CHECK_GT_OR_RETURN(out_blob_desc->shape().At(0), 0);
  CHECK_OR_RETURN(blob_conf.has_data_type());
  out_blob_desc->set_data_type(blob_conf.data_type());
  CHECK_OR_RETURN(blob_conf.has_is_dynamic());
  out_blob_desc->set_is_dynamic(blob_conf.is_dynamic());
  return Maybe<void>::Ok();
}
Expand All @@ -99,13 +103,21 @@ Maybe<void> InterfaceOpUtil::InitBlobConf(InterfaceBlobConf* blob_conf,
blob_desc.shape().ToProto(blob_conf->mutable_shape());
blob_conf->set_data_type(blob_desc.data_type());
blob_conf->set_is_dynamic(blob_desc.is_dynamic());
if (parallel_blob_conf.sbp_conf().has_split_parallel()) {
int64_t axis = parallel_blob_conf.sbp_conf().split_parallel().axis();
blob_conf->mutable_split_axis()->set_value(axis);
} else if (parallel_blob_conf.sbp_conf().has_broadcast_parallel()) {
blob_conf->mutable_split_axis()->clear_value();
*blob_conf->mutable_parallel_distribution() = parallel_blob_conf.parallel_distribution();
return Maybe<void>::Ok();
}

Maybe<void> InterfaceOpUtil::ParseParallelDistributionFromBlobConf(
const InterfaceBlobConf& blob_conf, const ParallelDesc& parallel_desc,
ParallelDistribution* parallel_distribution) {
const int64_t num_axes = parallel_desc.hierarchy()->NumAxes();
if (blob_conf.has_parallel_distribution()) {
*parallel_distribution = blob_conf.parallel_distribution();
} else {
OF_UNIMPLEMENTED();
parallel_distribution->clear_sbp_parallel();
FOR_RANGE(int64_t, i, 0, num_axes) {
parallel_distribution->add_sbp_parallel()->mutable_broadcast_parallel();
}
}
return Maybe<void>::Ok();
}
Expand Down
7 changes: 6 additions & 1 deletion oneflow/core/operator/interface_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace oneflow {

struct InterfaceOpUtil final {
static Maybe<void> InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc,
const ParallelContext* parallel_ctx);
const ParallelContext* parallel_ctx,
const ParallelDesc& parallel_desc);
static Maybe<void> InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf,
BlobDesc* out_blob_desc,
const ParallelDesc& parallel_desc);
Expand All @@ -38,6 +39,10 @@ struct InterfaceOpUtil final {
SbpSignature* sbp_signature);
static Maybe<void> InitBlobConf(InterfaceBlobConf* blob_conf,
const ParallelBlobConf& parallel_blob_conf);

static Maybe<void> ParseParallelDistributionFromBlobConf(
const InterfaceBlobConf& blob_conf, const ParallelDesc& parallel_desc,
ParallelDistribution* parallel_distribution);
};

} // namespace oneflow
Expand Down
31 changes: 25 additions & 6 deletions oneflow/core/operator/output_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ Maybe<void> OutputOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc,
parallel_desc);
JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc,
parallel_desc));
return Maybe<void>::Ok();
}

Expand All @@ -42,8 +42,8 @@ Maybe<void> OutputOp::InferOutBlobDescs(
if (in_blob_desc->is_dynamic()) {
*out_blob_desc = *in_blob_desc;
} else {
InterfaceOpUtil::InferOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc,
parallel_ctx);
JUST(InterfaceOpUtil::InferOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc,
parallel_ctx, *JUST(GetOpParallelDesc())));
CHECK_OR_RETURN(*out_blob_desc == *in_blob_desc);
}
return Maybe<void>::Ok();
Expand All @@ -54,8 +54,27 @@ Maybe<void> OutputOp::InferSbpSignature(
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const {
InterfaceOpUtil::GetOutputLikeOpSbpSignature(op_conf().output_conf().blob_conf(), input_bns(),
output_bns(), sbp_signature);
JUST(InterfaceOpUtil::GetOutputLikeOpSbpSignature(op_conf().output_conf().blob_conf(),
input_bns(), output_bns(), sbp_signature));
return Maybe<void>::Ok();
}

// Builds the parallel-distribution signature of the output op.
//
// Both "in" and "out" are filled by
// InterfaceOpUtil::ParseParallelDistributionFromBlobConf from the blob_conf of
// this op's output_conf, so the two entries receive the same distribution.
// Errors from the parser are propagated to the caller.
// `parallel_distribution_constraints` and `ParallelDistributionInferHint4Ibn`
// are not consulted by this implementation.
Maybe<void> OutputOp::InferParallelDistributionSignature(
    ParallelDistributionSignature* parallel_distribution_signature,
    const ParallelDistributionSignature& parallel_distribution_constraints,
    const ParallelDesc& parallel_desc,
    std::function<Maybe<const ParallelDistributionInferHint*>(const std::string&)>
        ParallelDistributionInferHint4Ibn) const {
  auto* bn2parallel_distribution =
      parallel_distribution_signature->mutable_bn_in_op2parallel_distribution();
  // Both entries are created up front, before either one is parsed.
  ParallelDistribution* in_distribution = &(*bn2parallel_distribution)["in"];
  ParallelDistribution* out_distribution = &(*bn2parallel_distribution)["out"];

  const InterfaceBlobConf& output_blob_conf = op_conf().output_conf().blob_conf();
  JUST(InterfaceOpUtil::ParseParallelDistributionFromBlobConf(output_blob_conf, parallel_desc,
                                                              in_distribution));
  JUST(InterfaceOpUtil::ParseParallelDistributionFromBlobConf(output_blob_conf, parallel_desc,
                                                              out_distribution));
  return Maybe<void>::Ok();
}

Expand Down
6 changes: 6 additions & 0 deletions oneflow/core/operator/output_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class OutputOp final : public Operator {
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const override;
Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;
// Fills `parallel_distribution_signature` for this op; overrides the virtual
// declared by the Operator base class. The definition lives in output_op.cpp,
// where the "in" and "out" entries of the signature are written.
Maybe<void> InferParallelDistributionSignature(
ParallelDistributionSignature* parallel_distribution_signature,
const ParallelDistributionSignature& parallel_distribution_constraints,
const ParallelDesc& parallel_desc,
std::function<Maybe<const ParallelDistributionInferHint*>(const std::string&)>
ParallelDistributionInferHint4Ibn) const override;
};

} // namespace oneflow
Expand Down
5 changes: 4 additions & 1 deletion oneflow/python/framework/input_blob_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import oneflow
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.operator.interface_blob_conf_pb2 as inter_face_blob_conf_util
import oneflow.core.job.sbp_parallel_pb2 as sbp_parallel_pb
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.compile_context as compile_context
import oneflow.python.framework.distribute as distribute_util
Expand Down Expand Up @@ -114,7 +115,9 @@ def ToInterfaceBlobConf(self):
interface_blob_conf.is_dynamic = self.is_dynamic
# NOTE(chengcheng): batch_axis was removed, so split_axis is always set to 0 to be safe.
# Setting sbp will be supported in the future, or this will be deleted in multi-client.
interface_blob_conf.split_axis.value = 0
sbp_parallel = sbp_parallel_pb.SbpParallel()
sbp_parallel.split_parallel.axis = 0
interface_blob_conf.parallel_distribution.sbp_parallel.extend([sbp_parallel])
return interface_blob_conf

def _Distribute2Str(self):
Expand Down
16 changes: 12 additions & 4 deletions oneflow/python/serving/inference_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import oneflow_api.oneflow.core.operator.interface_blob_conf as interface_blob_conf_proto_cfg
import oneflow_api.oneflow.core.common.shape as shape_proto_cfg
import oneflow_api.oneflow.core.common.data_type as dtype_proto_cfg
import oneflow_api.oneflow.core.job.sbp_parallel as sbp_parallel_cfg
import oneflow.core.job.job_conf_pb2 as job_conf_proto
import oneflow.core.operator.interface_blob_conf_pb2 as interface_blob_conf_proto
import oneflow.core.serving.saved_model_pb2 as saved_model_pb
Expand Down Expand Up @@ -106,10 +107,17 @@ def _inferface_blob_conf_proto_to_cfg(
dtype = dtype_proto_cfg.DataType(int(inferface_blob_conf_proto.data_type))
mut_inferface_blob_conf_cfg.set_data_type(dtype)

split_axis = dtype_proto_cfg.OptInt64()
if inferface_blob_conf_proto.split_axis.HasField("value"):
split_axis.set_value(inferface_blob_conf_proto.split_axis.value)
mut_inferface_blob_conf_cfg.mutable_split_axis().CopyFrom(split_axis)
if inferface_blob_conf_proto.HasField("parallel_distribution"):
# TODO(guoran): Process Nd sbp, parallel_distribution_cfg CopyFrom parallel_distribution_proto
assert len(inferface_blob_conf_proto.parallel_distribution.sbp_parallel) == 1
sbp_proto = inferface_blob_conf_proto.parallel_distribution.sbp_parallel[0]
if sbp_proto.HasField("split_parallel"):
split_axis = sbp_proto.split_parallel.axis
sbp = sbp_parallel_cfg.SbpParallel()
sbp.mutable_split_parallel().set_axis(split_axis)
mut_inferface_blob_conf_cfg.mutable_parallel_distribution().mutable_sbp_parallel().Add().CopyFrom(
sbp
)

mut_inferface_blob_conf_cfg.set_is_dynamic(inferface_blob_conf_proto.is_dynamic)

Expand Down
6 changes: 5 additions & 1 deletion oneflow/python/serving/saved_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import oneflow.core.job.job_conf_pb2 as job_conf_pb
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_pb
import oneflow.core.operator.interface_blob_conf_pb2 as interface_blob_conf_pb
import oneflow.core.job.sbp_parallel_pb2 as sbp_parallel_pb
from oneflow.python.oneflow_export import oneflow_export


Expand Down Expand Up @@ -315,7 +316,10 @@ def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
blob_conf.shape.dim.extend(shape)
blob_conf.data_type = dtype
if split_axis is not None:
blob_conf.split_axis.value = split_axis
sbp_parallel = sbp_parallel_pb.SbpParallel()
sbp_parallel.split_parallel.axis = split_axis
blob_conf.parallel_distribution.sbp_parallel.extend([sbp_parallel])

blob_conf.is_dynamic = is_dynamic
return blob_conf

Expand Down

0 comments on commit 9d6ab3f

Please sign in to comment.