diff --git a/oneflow/core/framework/op_arg_util.cpp b/oneflow/core/framework/op_arg_util.cpp index c899ed45c63..ac0f01f6880 100644 --- a/oneflow/core/framework/op_arg_util.cpp +++ b/oneflow/core/framework/op_arg_util.cpp @@ -113,10 +113,13 @@ void OpArgParallelAttribute::Assign(const std::shared_ptr interface_blob_conf) const { + interface_blob_conf->mutable_parallel_distribution()->clear_sbp_parallel(); if (sbp_parallel_->has_split_parallel()) { - interface_blob_conf->mutable_split_axis()->set_value(sbp_parallel_->split_parallel().axis()); + *interface_blob_conf->mutable_parallel_distribution()->add_sbp_parallel() = *sbp_parallel_; } else { - interface_blob_conf->clear_split_axis(); + interface_blob_conf->mutable_parallel_distribution() + ->add_sbp_parallel() + ->mutable_broadcast_parallel(); } } diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 7714a81a77c..66b1da46fcf 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -169,7 +169,7 @@ message IndexedSlicesOptimizerConf { message ParallelBlobConf { required BlobDescProto logical_blob_desc_conf = 1; required ParallelConf parallel_conf = 2; - required SbpParallel sbp_conf = 3; + required ParallelDistribution parallel_distribution = 3; } message JobInputDef { diff --git a/oneflow/core/job/model_io_job.cpp b/oneflow/core/job/model_io_job.cpp index 6fbbffc0954..fd4ef699cb8 100644 --- a/oneflow/core/job/model_io_job.cpp +++ b/oneflow/core/job/model_io_job.cpp @@ -46,7 +46,6 @@ OperatorConf GenForeignInputOpConf(const std::string& job_name, const int64_t in *blob_conf->mutable_shape()->mutable_dim()->Add() = input_size; blob_conf->set_data_type(DataType::kInt8); blob_conf->set_is_dynamic(true); - blob_conf->mutable_split_axis()->clear_value(); return foreign_input_op_conf; } diff --git a/oneflow/core/job/model_io_v2_job.cpp b/oneflow/core/job/model_io_v2_job.cpp index 3eade86cada..9816be020e6 100644 --- a/oneflow/core/job/model_io_v2_job.cpp +++ b/oneflow/core/job/model_io_v2_job.cpp @@ -46,7 +46,6 @@ OperatorConf GenForeignInputOpConf(const std::string& job_name, const int64_t in *blob_conf->mutable_shape()->mutable_dim()->Add() = input_size; blob_conf->set_is_dynamic(true); blob_conf->set_data_type(DataType::kInt8); - blob_conf->mutable_split_axis()->clear_value(); return foreign_input_op_conf; } diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index c863e480410..293b3318072 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -60,7 +60,8 @@ namespace oneflow { bool operator==(const ParallelBlobConf& lhs, const ParallelBlobConf& rhs) { return BlobDesc(lhs.logical_blob_desc_conf()) == BlobDesc(rhs.logical_blob_desc_conf()) - && lhs.parallel_conf() == rhs.parallel_conf() && lhs.sbp_conf() == rhs.sbp_conf(); + && lhs.parallel_conf() == rhs.parallel_conf() + && lhs.parallel_distribution() == rhs.parallel_distribution(); } namespace { @@ -504,11 +505,11 @@ void GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& o ParallelBlobConf ret; *blob_conf->mutable_parallel_conf() = job_builder.ParallelConf4OpName(op_name); *blob_conf->mutable_logical_blob_desc_conf() = job.helper().lbn2logical_blob_desc().at(lbn); - *blob_conf->mutable_sbp_conf() = job.job_parallel_view_conf() - .op_name2sbp_signature_conf() - .at(op_name) - .bn_in_op2sbp_parallel() - .at(obn); + *blob_conf->mutable_parallel_distribution() = job.job_parallel_view_conf() + .op_name2parallel_distribution_signature_conf() + .at(op_name) + .bn_in_op2parallel_distribution() + .at(obn); } void FilterOpName2ParallelBlobConf( diff --git a/oneflow/core/operator/input_op.cpp b/oneflow/core/operator/input_op.cpp index 6e52adab56c..0264f266e8d 100644 --- a/oneflow/core/operator/input_op.cpp +++ b/oneflow/core/operator/input_op.cpp @@ -42,7 +42,7 @@ Maybe InputOp::InferOutBlobDescs( const ParallelContext* parallel_ctx) const { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); JUST(InterfaceOpUtil::InferOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc, - parallel_ctx)); + parallel_ctx, *JUST(GetOpParallelDesc()))); return Maybe::Ok(); } @@ -51,15 +51,36 @@ Maybe InputOp::InferSbpSignature( const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { - InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(), - output_bns(), sbp_signature); + JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(), + output_bns(), sbp_signature)); return Maybe::Ok(); } Maybe InputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { - InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(), - output_bns(), - sbp_sig_list->mutable_sbp_signature()->Add()); + JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(), + output_bns(), + sbp_sig_list->mutable_sbp_signature()->Add())); + return Maybe::Ok(); +} + +Maybe InputOp::InferParallelDistributionSignature( + ParallelDistributionSignature* parallel_distribution_signature, + const ParallelDistributionSignature& parallel_distribution_constraints, + const ParallelDesc& parallel_desc, + std::function(const std::string&)> + ParallelDistributionInferHint4Ibn) const { + const auto& parallel_hierarchy = parallel_desc.hierarchy(); + const InterfaceBlobConf& blob_conf = op_conf().input_conf().blob_conf(); + ParallelDistribution& tick_parallel_distribution = + (*parallel_distribution_signature->mutable_bn_in_op2parallel_distribution())["tick"]; + tick_parallel_distribution.clear_sbp_parallel(); + FOR_RANGE(int64_t, i, 0, parallel_hierarchy->NumAxes()) { + tick_parallel_distribution.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); + } + ParallelDistribution& out_parallel_distribution = + (*parallel_distribution_signature->mutable_bn_in_op2parallel_distribution())["out"]; + JUST(InterfaceOpUtil::ParseParallelDistributionFromBlobConf(blob_conf, parallel_desc, + &out_parallel_distribution)); return Maybe::Ok(); } diff --git a/oneflow/core/operator/input_op.h b/oneflow/core/operator/input_op.h index 64bb7093668..f64bb67a578 100644 --- a/oneflow/core/operator/input_op.h +++ b/oneflow/core/operator/input_op.h @@ -43,6 +43,12 @@ class InputOp final : public Operator { Maybe GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; + Maybe InferParallelDistributionSignature( + ParallelDistributionSignature* parallel_distribution_signature, + const ParallelDistributionSignature& parallel_distribution_constraints, + const ParallelDesc& parallel_desc, + std::function(const std::string&)> + ParallelDistributionInferHint4Ibn) const override; }; } // namespace oneflow diff --git a/oneflow/core/operator/interface_blob_conf.proto b/oneflow/core/operator/interface_blob_conf.proto index 9aa7422a9b1..273c9fc4a11 100644 --- a/oneflow/core/operator/interface_blob_conf.proto +++ b/oneflow/core/operator/interface_blob_conf.proto @@ -3,10 +3,11 @@ package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; +import "oneflow/core/job/sbp_parallel.proto"; message InterfaceBlobConf { optional ShapeProto shape = 1; optional DataType data_type = 2; - optional OptInt64 split_axis = 3; - optional bool is_dynamic = 5; + optional bool is_dynamic = 3; + optional ParallelDistribution parallel_distribution = 4; } diff --git a/oneflow/core/operator/interface_op_util.cpp b/oneflow/core/operator/interface_op_util.cpp index e2b09eb3e20..2fb3d89c49a 100644 --- a/oneflow/core/operator/interface_op_util.cpp +++ b/oneflow/core/operator/interface_op_util.cpp @@ -27,13 +27,17 @@ void CheckShape(const Shape& shape) { Maybe GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, SbpSignature* sbp_signature, bool is_for_input_op) { - if (blob_conf.split_axis().has_value()) { + if (!blob_conf.has_parallel_distribution()) { + SbpSignatureBuilder().Broadcast(input_bns).Broadcast(output_bns).Build(sbp_signature); + return Maybe::Ok(); + } + CHECK_EQ_OR_RETURN(blob_conf.parallel_distribution().sbp_parallel_size(), 1); + const auto& sbp_parallel = blob_conf.parallel_distribution().sbp_parallel(0); + if (sbp_parallel.has_split_parallel()) { int64_t num_axes = blob_conf.shape().dim_size(); - int64_t split_axis = blob_conf.split_axis().value(); - if (split_axis < 0) { split_axis += num_axes; } + int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_GE_OR_RETURN(split_axis, 0); CHECK_LT_OR_RETURN(split_axis, num_axes); - SbpSignatureBuilder sbp_signature_builder; if (is_for_input_op) { // broadcast tick args for InputOp @@ -52,27 +56,27 @@ Maybe GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf InterfaceOpUtil::InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, - const ParallelContext* parallel_ctx) { - out_blob_desc->mut_shape() = Shape(blob_conf.shape()); - CheckShape(out_blob_desc->shape()); - CHECK_GT(out_blob_desc->mut_shape().At(0), 0); + const ParallelContext* parallel_ctx, + const ParallelDesc& parallel_desc) { + ParallelDistribution parallel_distribution; + JUST(ParseParallelDistributionFromBlobConf(blob_conf, parallel_desc, ¶llel_distribution)); + out_blob_desc->mut_shape() = *JUST(GetPhysicalShape( + Shape(blob_conf.shape()), parallel_distribution, parallel_desc, *parallel_ctx)); out_blob_desc->set_data_type(blob_conf.data_type()); out_blob_desc->set_is_dynamic(blob_conf.is_dynamic()); - if (blob_conf.split_axis().has_value()) { - int64_t split_axis = blob_conf.split_axis().value(); - BalancedSplitter bs(out_blob_desc->shape().At(split_axis), parallel_ctx->parallel_num()); - out_blob_desc->mut_shape().Set(split_axis, bs.At(parallel_ctx->parallel_id()).size()); - } return Maybe::Ok(); } Maybe InterfaceOpUtil::InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelDesc& parallel_desc) { + CHECK_OR_RETURN(blob_conf.has_shape()); out_blob_desc->mut_shape() = Shape(blob_conf.shape()); CheckShape(out_blob_desc->shape()); CHECK_GT(out_blob_desc->mut_shape().At(0), 0); + CHECK_OR_RETURN(blob_conf.has_data_type()); out_blob_desc->set_data_type(blob_conf.data_type()); + CHECK_OR_RETURN(blob_conf.has_is_dynamic()); out_blob_desc->set_is_dynamic(blob_conf.is_dynamic()); return Maybe::Ok(); } @@ -99,13 +103,21 @@ Maybe InterfaceOpUtil::InitBlobConf(InterfaceBlobConf* blob_conf, blob_desc.shape().ToProto(blob_conf->mutable_shape()); blob_conf->set_data_type(blob_desc.data_type()); blob_conf->set_is_dynamic(blob_desc.is_dynamic()); - if (parallel_blob_conf.sbp_conf().has_split_parallel()) { - int64_t axis = parallel_blob_conf.sbp_conf().split_parallel().axis(); - blob_conf->mutable_split_axis()->set_value(axis); - } else if (parallel_blob_conf.sbp_conf().has_broadcast_parallel()) { - blob_conf->mutable_split_axis()->clear_value(); + *blob_conf->mutable_parallel_distribution() = parallel_blob_conf.parallel_distribution(); + return Maybe::Ok(); +} + +Maybe InterfaceOpUtil::ParseParallelDistributionFromBlobConf( + const InterfaceBlobConf& blob_conf, const ParallelDesc& parallel_desc, + ParallelDistribution* parallel_distribution) { + const int64_t num_axes = parallel_desc.hierarchy()->NumAxes(); + if (blob_conf.has_parallel_distribution()) { + *parallel_distribution = blob_conf.parallel_distribution(); } else { - OF_UNIMPLEMENTED(); + parallel_distribution->clear_sbp_parallel(); + FOR_RANGE(int64_t, i, 0, num_axes) { + parallel_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } } return Maybe::Ok(); } diff --git a/oneflow/core/operator/interface_op_util.h b/oneflow/core/operator/interface_op_util.h index 9580d310d5c..86673e36bf5 100644 --- a/oneflow/core/operator/interface_op_util.h +++ b/oneflow/core/operator/interface_op_util.h @@ -24,7 +24,8 @@ namespace oneflow { struct InterfaceOpUtil final { static Maybe InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, - const ParallelContext* parallel_ctx); + const ParallelContext* parallel_ctx, + const ParallelDesc& parallel_desc); static Maybe InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelDesc& parallel_desc); @@ -38,6 +39,10 @@ struct InterfaceOpUtil final { SbpSignature* sbp_signature); static Maybe InitBlobConf(InterfaceBlobConf* blob_conf, const ParallelBlobConf& parallel_blob_conf); + + static Maybe ParseParallelDistributionFromBlobConf( + const InterfaceBlobConf& blob_conf, const ParallelDesc& parallel_desc, + ParallelDistribution* parallel_distribution); }; } // namespace oneflow diff --git a/oneflow/core/operator/output_op.cpp b/oneflow/core/operator/output_op.cpp index c2ec6613492..d9040f91cc2 100644 --- a/oneflow/core/operator/output_op.cpp +++ b/oneflow/core/operator/output_op.cpp @@ -29,8 +29,8 @@ Maybe OutputOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); - InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc, - parallel_desc); + JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc, + parallel_desc)); return Maybe::Ok(); } @@ -42,8 +42,8 @@ Maybe OutputOp::InferOutBlobDescs( if (in_blob_desc->is_dynamic()) { *out_blob_desc = *in_blob_desc; } else { - InterfaceOpUtil::InferOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc, - parallel_ctx); + JUST(InterfaceOpUtil::InferOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc, + parallel_ctx, *JUST(GetOpParallelDesc()))); CHECK_OR_RETURN(*out_blob_desc == *in_blob_desc); } return Maybe::Ok(); @@ -54,8 +54,27 @@ Maybe OutputOp::InferSbpSignature( const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const { - InterfaceOpUtil::GetOutputLikeOpSbpSignature(op_conf().output_conf().blob_conf(), input_bns(), - output_bns(), sbp_signature); + JUST(InterfaceOpUtil::GetOutputLikeOpSbpSignature(op_conf().output_conf().blob_conf(), + input_bns(), output_bns(), sbp_signature)); + return Maybe::Ok(); +} + +Maybe OutputOp::InferParallelDistributionSignature( + ParallelDistributionSignature* parallel_distribution_signature, + const ParallelDistributionSignature& parallel_distribution_constraints, + const ParallelDesc& parallel_desc, + std::function(const std::string&)> + ParallelDistributionInferHint4Ibn) const { + const InterfaceBlobConf& blob_conf = op_conf().output_conf().blob_conf(); + ParallelDistribution& in_parallel_distribution = + (*parallel_distribution_signature->mutable_bn_in_op2parallel_distribution())["in"]; + ParallelDistribution& out_parallel_distribution = + (*parallel_distribution_signature->mutable_bn_in_op2parallel_distribution())["out"]; + JUST(InterfaceOpUtil::ParseParallelDistributionFromBlobConf(blob_conf, parallel_desc, + &in_parallel_distribution)); + JUST(InterfaceOpUtil::ParseParallelDistributionFromBlobConf(blob_conf, parallel_desc, + &out_parallel_distribution)); + return Maybe::Ok(); } diff --git a/oneflow/core/operator/output_op.h b/oneflow/core/operator/output_op.h index 1966fb58a36..684367f2259 100644 --- a/oneflow/core/operator/output_op.h +++ b/oneflow/core/operator/output_op.h @@ -41,6 +41,12 @@ class OutputOp final : public Operator { std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; + Maybe InferParallelDistributionSignature( + ParallelDistributionSignature* parallel_distribution_signature, + const ParallelDistributionSignature& parallel_distribution_constraints, + const ParallelDesc& parallel_desc, + std::function(const std::string&)> + ParallelDistributionInferHint4Ibn) const override; }; } // namespace oneflow diff --git a/oneflow/python/framework/input_blob_def.py b/oneflow/python/framework/input_blob_def.py index 9fb6d6e7242..364b82a4f56 100644 --- a/oneflow/python/framework/input_blob_def.py +++ b/oneflow/python/framework/input_blob_def.py @@ -24,6 +24,7 @@ import oneflow import oneflow.core.operator.op_conf_pb2 as op_conf_util import oneflow.core.operator.interface_blob_conf_pb2 as inter_face_blob_conf_util +import oneflow.core.job.sbp_parallel_pb2 as sbp_parallel_pb import oneflow.python.framework.c_api_util as c_api_util import oneflow.python.framework.compile_context as compile_context import oneflow.python.framework.distribute as distribute_util @@ -114,7 +115,9 @@ def ToInterfaceBlobConf(self): interface_blob_conf.is_dynamic = self.is_dynamic # NOTE(chengcheng): rm batch_axis, so set split_axis always = 0 for safe. will support # set sbp in future, or will delete in multi-client - interface_blob_conf.split_axis.value = 0 + sbp_parallel = sbp_parallel_pb.SbpParallel() + sbp_parallel.split_parallel.axis = 0 + interface_blob_conf.parallel_distribution.sbp_parallel.extend([sbp_parallel]) return interface_blob_conf def _Distribute2Str(self): diff --git a/oneflow/python/serving/inference_session.py b/oneflow/python/serving/inference_session.py index 2fa83bd051f..4f76614ca4b 100644 --- a/oneflow/python/serving/inference_session.py +++ b/oneflow/python/serving/inference_session.py @@ -28,6 +28,7 @@ import oneflow_api.oneflow.core.operator.interface_blob_conf as interface_blob_conf_proto_cfg import oneflow_api.oneflow.core.common.shape as shape_proto_cfg import oneflow_api.oneflow.core.common.data_type as dtype_proto_cfg +import oneflow_api.oneflow.core.job.sbp_parallel as sbp_parallel_cfg import oneflow.core.job.job_conf_pb2 as job_conf_proto import oneflow.core.operator.interface_blob_conf_pb2 as interface_blob_conf_proto import oneflow.core.serving.saved_model_pb2 as saved_model_pb @@ -106,10 +107,17 @@ def _inferface_blob_conf_proto_to_cfg( dtype = dtype_proto_cfg.DataType(int(inferface_blob_conf_proto.data_type)) mut_inferface_blob_conf_cfg.set_data_type(dtype) - split_axis = dtype_proto_cfg.OptInt64() - if inferface_blob_conf_proto.split_axis.HasField("value"): - split_axis.set_value(inferface_blob_conf_proto.split_axis.value) - mut_inferface_blob_conf_cfg.mutable_split_axis().CopyFrom(split_axis) + if inferface_blob_conf_proto.HasField("parallel_distribution"): + # TODO(guoran): Process Nd sbp, parallel_distribution_cfg CopyFrom parallel_distribution_proto + assert len(inferface_blob_conf_proto.parallel_distribution.sbp_parallel) == 1 + sbp_proto = inferface_blob_conf_proto.parallel_distribution.sbp_parallel[0] + if sbp_proto.HasField("split_parallel"): + split_axis = sbp_proto.split_parallel.axis + sbp = sbp_parallel_cfg.SbpParallel() + sbp.mutable_split_parallel().set_axis(split_axis) + mut_inferface_blob_conf_cfg.mutable_parallel_distribution().mutable_sbp_parallel().Add().CopyFrom( + sbp + ) mut_inferface_blob_conf_cfg.set_is_dynamic(inferface_blob_conf_proto.is_dynamic) diff --git a/oneflow/python/serving/saved_model_builder.py b/oneflow/python/serving/saved_model_builder.py index 6f7a2b96598..e97c7ffd267 100644 --- a/oneflow/python/serving/saved_model_builder.py +++ b/oneflow/python/serving/saved_model_builder.py @@ -26,6 +26,7 @@ import oneflow.core.job.job_conf_pb2 as job_conf_pb import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_pb import oneflow.core.operator.interface_blob_conf_pb2 as interface_blob_conf_pb +import oneflow.core.job.sbp_parallel_pb2 as sbp_parallel_pb from oneflow.python.oneflow_export import oneflow_export @@ -315,7 +316,10 @@ def GetInterfaceBlobConf(job_name, lbn, blob_conf=None): blob_conf.shape.dim.extend(shape) blob_conf.data_type = dtype if split_axis is not None: - blob_conf.split_axis.value = split_axis + sbp_parallel = sbp_parallel_pb.SbpParallel() + sbp_parallel.split_parallel.axis = split_axis + blob_conf.parallel_distribution.sbp_parallel.extend([sbp_parallel]) + blob_conf.is_dynamic = is_dynamic return blob_conf