Plan seperation compile #9913

Closed
wants to merge 43 commits into from
Commits (43)
560a511
implement RankTaskGraph
lixinqi Sep 19, 2022
51034f0
RankCompiler
lixinqi Sep 19, 2022
3736494
fix compiler complaints
lixinqi Sep 20, 2022
5807882
CompTaskNode::ConsumeFakeRegsts
lixinqi Sep 20, 2022
c07ea5e
TransportTaskProto::lbi
lixinqi Sep 20, 2022
1fd10ef
makes sure all ranks know all var_op_names
lixinqi Sep 22, 2022
74c96df
RankTaskGraph::ForEachDutyRank
lixinqi Sep 22, 2022
e89143b
PortableCtrlEdge
lixinqi Sep 23, 2022
1b10509
compile in MultiThreadLoop
lixinqi Sep 23, 2022
44bf12b
CompileMode
lixinqi Sep 23, 2022
bd50bc7
rebuild new_task_id_ before ProduceRegst
lixinqi Sep 26, 2022
7853956
RankTaskGraph::InitRegstDescsConsumers()
lixinqi Sep 26, 2022
b725318
PlanUtil::GenReachableTaskPairs
lixinqi Sep 27, 2022
45bc629
disable checking consumer_task_regst_desc_id_size
lixinqi Sep 27, 2022
3c4ea9d
TaskNode::InitConsumedRegstsFromProto
lixinqi Sep 27, 2022
9880ba4
remove RegstDesc::InitConsumersFromProto
lixinqi Sep 27, 2022
20175fc
refactor CompTaskNode::ConsumeFakeRegstsIf
lixinqi Sep 27, 2022
fbff274
refactor CompTaskNode::ConsumeFakeRegsts
lixinqi Sep 27, 2022
ede3cd2
remove Plan::fake_consumed_regst_desc_id
lixinqi Sep 27, 2022
3ba45e5
revert part of code in job/plan_util.cpp
lixinqi Sep 28, 2022
2e9ab1a
refacotr ParallelDesc::TryGetParallelId
lixinqi Sep 28, 2022
93a7947
cut boxing_task_graph by rank
lixinqi Sep 29, 2022
818d14d
make sure TaskIdGenerator::Generator is thread safe
lixinqi Oct 8, 2022
8ca22bf
atomic<int64_t> mem_block_id
lixinqi Oct 8, 2022
2adbb13
chunk id add lock
strint Oct 8, 2022
ccf9bea
get chunk proto with lock
strint Oct 9, 2022
2c577df
create chunk with lock
strint Oct 9, 2022
fa49459
mutable std::mutex
lixinqi Oct 9, 2022
a4e67b0
Rank task graph merge master (#9440)
strint Nov 22, 2022
7d69c25
fix conflict
strint Nov 22, 2022
d4782a7
auto format by CI
oneflow-ci-bot Nov 22, 2022
eb76987
fix conflict
strint Nov 22, 2022
bb9e65e
Merge branch 'rank_task_graph' of https://github.com/Oneflow-Inc/onef…
strint Nov 22, 2022
6b575fc
fix conflict
strint Nov 22, 2022
13ba2ac
auto format by CI
oneflow-ci-bot Nov 22, 2022
92face0
fix conflict
strint Nov 22, 2022
1b2edca
fix
strint Nov 22, 2022
3910af6
address pr comments
lixinqi Nov 24, 2022
a37f9f8
fix bug
strint Dec 13, 2022
132a8a7
Rank task graph fix (#9749)
strint Feb 28, 2023
f1352e6
rm useless
strint Feb 28, 2023
6b13581
fix muti thread merge bug
strint Feb 28, 2023
035e83c
Plan sep compile merge master (#9915)
strint Mar 1, 2023
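Note: several commits above ("make sure TaskIdGenerator::Generator is thread safe", "atomic<int64_t> mem_block_id", "chunk id add lock", "get chunk proto with lock") are about making ID allocation safe once plans are compiled per rank in parallel threads. The following is only a minimal sketch of that pattern with hypothetical names (IdState, NewMemBlockId, GetOrCreateChunk); it is not the actual OneFlow API.

#include <atomic>
#include <cstdint>
#include <mutex>
#include <unordered_map>

// Illustrative only: an atomic counter hands out unique ids without a lock,
// while a mutex guards the map that caches per-key chunk records.
class IdState {
 public:
  int64_t NewMemBlockId() { return next_mem_block_id_.fetch_add(1); }

  int64_t GetOrCreateChunk(int64_t key) {
    std::lock_guard<std::mutex> lock(mutex_);  // serialize map mutation across compile threads
    auto it = key2chunk_id_.find(key);
    if (it != key2chunk_id_.end()) { return it->second; }
    int64_t chunk_id = next_chunk_id_++;
    key2chunk_id_.emplace(key, chunk_id);
    return chunk_id;
  }

 private:
  std::atomic<int64_t> next_mem_block_id_{0};
  int64_t next_chunk_id_ = 0;
  std::unordered_map<int64_t, int64_t> key2chunk_id_;
  mutable std::mutex mutex_;
};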
402 changes: 223 additions & 179 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -3373,14 +3373,42 @@ class FillTensorFunctor {
};

class IndexAddFunctor {
 public:
  IndexAddFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("index_add")
                         .Input("input")
                         .Input("index")
                         .Input("source")
                         .Output("output")
                         .Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& dim,
                           const std::shared_ptr<one::Tensor>& index,
                           const std::shared_ptr<one::Tensor>& source, const Scalar& alpha) const {
    CHECK_OR_RETURN(source->ndim() == 0 || index->shape()->Count(0) == source->shape()->At(dim))
        << "index_add_(): Number of indices (" << index->shape()->Count(0)
        << ") should be equal to source.size(dim) (" << source->shape()->At(dim) << ")";
    CHECK_OR_RETURN(index->dtype()->data_type() == DataType::kInt32
                    || index->dtype()->data_type() == DataType::kInt64)
        << "Input(Index) holds the wrong type, it holds "
        << DataType_Name(index->dtype()->data_type())
        << ", but desires to be int32_t or int64_t";
    const float alpha_value = alpha.As<float>();
    int64_t dim_ = dim;
    dim_ = JUST(maybe_wrap_dim(dim_, input->ndim()));
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim", "alpha");
    attrs.SetAllAttrs(dim_, alpha_value);
    TensorProcessor tensor_processor;
    JUST(tensor_processor.PromoteInputsToCommonDtype(true, input->dtype())
             .AddInputs({input, source})
             .Apply());
    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, input_tuple.at(1)}, attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

class IndexAddInplaceFunctor {
@@ -3459,196 +3487,212 @@ class BroadcastTensorsFunctor {
};
class BinCountFunctor {
 public:
  BinCountFunctor() {
    op_ = CHECK_JUST(OpBuilder("bincount").Input("in").Output("out").Build());
    weight_op_ =
        CHECK_JUST(OpBuilder("bincount").Input("in").Input("weight").Output("out").Build());
  }

  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const Optional<Tensor>& weight,
                           const Optional<int64_t>& minlength) const {
    CHECK_OR_RETURN(!input->dtype()->is_floating_point()) << "bincount can only support int tensor";
    TensorProcessor tensor_processor;
    JUST(tensor_processor.AddInputs({input}, DType::Int64()).Apply());
    const auto x = JUST(tensor_processor.GetInputs()).at(0);
    std::shared_ptr<Tensor> local_tensor = x;
    int64_t max = 0;

    // check min value
    {
      if (x->is_global()) { local_tensor = JUST(GlobalToLocal(x, false)); }
      auto tensor_min = JUST(functional::Min(local_tensor));
      int64_t min = 0;
      const auto& callback_min =
          [&](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
            SyncAutoMemcpy(stream, &min, eager_blob_object->dptr(), sizeof(min),
                           memory::MakeHostMemCase(), eager_blob_object->mem_case());
          };
      JUST(SyncAccessTensorWithTimeOut(tensor_min, callback_min, "const"));
      CHECK_GE_OR_RETURN(min, 0) << "bincount only supports 1-d non-negative integral inputs.";

      auto tensor_max = JUST(functional::Max(local_tensor));
      const auto& callback_max =
          [&](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
            SyncAutoMemcpy(stream, &max, eager_blob_object->dptr(), sizeof(max),
                           memory::MakeHostMemCase(), eager_blob_object->mem_case());
          };
      JUST(SyncAccessTensorWithTimeOut(tensor_max, callback_max, "const"));
      max += 1;
    }
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("size");
    if (minlength) {
      CHECK_GE_OR_RETURN(JUST(minlength), 0) << "minlength should be >= 0";
      max = std::max(JUST(minlength), max);
    }
    attrs.SetAllAttrs(max);
    if (weight) {
      CHECK_EQ_OR_RETURN(JUST(weight)->nelement(), x->nelement())
          << "input and weights should have the same length";
      return OpInterpUtil::Dispatch<Tensor>(*weight_op_, {x, JUST(weight)}, attrs);
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
  std::shared_ptr<OpExpr> weight_op_;
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::ArgMaxFunctor>("ArgMax");
m.add_functor<impl::ArgMinFunctor>("ArgMin");
m.add_functor<impl::GlobalConstantFunctor>("GlobalConstant");
m.add_functor<impl::ConstantFunctor>("Constant");
m.add_functor<impl::GlobalEmptyFunctor>("GlobalEmpty");
m.add_functor<impl::EmptyFunctor>("Empty");
m.add_functor<impl::ZerosLikeFunctor>("ZerosLike");
m.add_functor<impl::OnesLikeFunctor>("OnesLike");
m.add_functor<impl::FlattenFunctor>("Flatten");
m.add_functor<impl::FillFunctor>("Fill");
m.add_functor<impl::FillTensorFunctor>("FillTensor");
m.add_functor<impl::WhereFunctor>("Where");
m.add_functor<impl::WhereScalarXFunctor>("WhereScalarX");
m.add_functor<impl::WhereScalarYFunctor>("WhereScalarY");
m.add_functor<impl::WhereScalarXYFunctor>("WhereScalarXY");
m.add_functor<impl::ArgWhereFunctor>("ArgWhere");
m.add_functor<impl::NonZeroFunctor>("NonZero");
m.add_functor<impl::BroadcastLikeFunctor>("BroadcastLike");
m.add_functor<impl::ConcatFunctor>("Concat");
m.add_functor<impl::StackFunctor>("Stack");
m.add_functor<impl::StackGradFunctor>("StackGrad");
m.add_functor<impl::AtLeast1DFunctor>("AtLeast1D");
m.add_functor<impl::AtLeast1DListFunctor>("AtLeast1D");
m.add_functor<impl::AtLeast2DFunctor>("AtLeast2D");
m.add_functor<impl::AtLeast2DListFunctor>("AtLeast2D");
m.add_functor<impl::AtLeast3DFunctor>("AtLeast3D");
m.add_functor<impl::AtLeast3DListFunctor>("AtLeast3D");
m.add_functor<impl::HStackFunctor>("HStack");
m.add_functor<impl::ColumnStackFunctor>("ColumnStack");
m.add_functor<impl::VStackFunctor>("VStack");
m.add_functor<impl::RowStackFunctor>("RowStack");
m.add_functor<impl::DStackFunctor>("DStack");
m.add_functor<impl::ExpandFunctor>("Expand");
m.add_functor<impl::ExpandDimsFunctor>("ExpandDims");
m.add_functor<impl::ExpandDimsFunctor>("Unsqueeze");
m.add_functor<impl::UnsqueezeMultipleFunctor>("UnsqueezeMultiple");
m.add_functor<impl::SqueezeFunctor>("Squeeze");
m.add_functor<impl::RollFunctor>("Roll");
m.add_functor<impl::GatherFunctor>("Gather");
m.add_functor<impl::DimGatherFunctor>("DimGather");
m.add_functor<impl::ArgSortFunctor>("ArgSort");
m.add_functor<impl::SearchSortedFunctor>("SearchSorted");
m.add_functor<impl::SearchSortedScalarFunctor>("SearchSortedScalar");
m.add_functor<impl::GatherNdFunctor>("GatherNd");
m.add_functor<impl::ScatterNdFunctor>("ScatterNd");
m.add_functor<impl::TensorScatterNdUpdateFunctor>("TensorScatterNdUpdate");
m.add_functor<impl::ScatterNdLikeFunctor>("ScatterNdLike");
m.add_functor<impl::ReshapeFunctor>("Reshape");
m.add_functor<impl::ViewFunctor>("View");
m.add_functor<impl::ToContiguousFunctor>("ToContiguous");
m.add_functor<impl::InplaceToContiguousFunctor>("InplaceToContiguous");
m.add_functor<impl::NarrowFunctor>("Narrow");
m.add_functor<impl::NarrowGradFunctor>("NarrowGrad");
m.add_functor<impl::SliceUpdateFunctor>("SliceUpdate");
m.add_functor<impl::SliceFunctor>("Slice");
m.add_functor<impl::SliceGradFunctor>("SliceGrad");
m.add_functor<impl::SliceView1dContiguousFunctor>("SliceView1dContiguous");
m.add_functor<impl::CopyFunctor>("Copy");
m.add_functor<impl::FlipFunctor>("Flip");
m.add_functor<impl::UnfoldTensorFunctor>("UnfoldTensor");
m.add_functor<impl::UnfoldTensorGradFunctor>("UnfoldTensorGrad");
m.add_functor<impl::UpsampleGradFunctor>("UpsampleGrad");
m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D");
m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad");
m.add_functor<impl::UpsampleBilinear2DFunctor>("UpsampleBilinear2D");
m.add_functor<impl::UpsampleBilinear2DGradFunctor>("UpsampleBilinear2DGrad");
m.add_functor<impl::UpsampleLinear1DFunctor>("UpsampleLinear1D");
m.add_functor<impl::UpsampleLinear1DGradFunctor>("UpsampleLinear1DGrad");
m.add_functor<impl::UpsampleNearest1DFunctor>("UpsampleNearest1D");
m.add_functor<impl::UpsampleNearest1DGradFunctor>("UpsampleNearest1DGrad");
m.add_functor<impl::UpsampleBicubic2DFunctor>("UpsampleBicubic2D");
m.add_functor<impl::UpsampleBicubic2DGradFunctor>("UpsampleBicubic2DGrad");
m.add_functor<impl::UpsampleNearest3DFunctor>("UpsampleNearest3D");
m.add_functor<impl::UpsampleNearest3DGradFunctor>("UpsampleNearest3DGrad");
m.add_functor<impl::UpsampleTrilinear3DFunctor>("UpsampleTrilinear3D");
m.add_functor<impl::UpsampleTrilinear3DGradFunctor>("UpsampleTrilinear3DGrad");
m.add_functor<impl::UnsortedSegmentSumLikeFunctor>("UnsortedSegmentSumLike");
m.add_functor<impl::UnsortedSegmentSumFunctor>("UnsortedSegmentSum");
m.add_functor<impl::TrilFunctor>("Tril");
m.add_functor<impl::TriuFunctor>("Triu");
m.add_functor<impl::InplaceTriuFunctor>("InplaceTriu");
m.add_functor<impl::DiagFunctor>("Diag");
m.add_functor<impl::DiagGradFunctor>("DiagGrad");
m.add_functor<impl::DiagonalFunctor>("Diagonal");
m.add_functor<impl::DiagonalGradFunctor>("DiagonalGrad");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
m.add_functor<impl::DimScatterFunctorImpl<impl::DimScatterType::kUpdate>>("DimScatterUpdate");
m.add_functor<impl::DimScatterFunctorImpl<impl::DimScatterType::kAdd>>("DimScatterAdd");
m.add_functor<impl::DimScatterFunctorImpl<impl::DimScatterType::kMultiply>>("DimScatterMul");
m.add_functor<impl::DimScatterFunctor>("DimScatter");
m.add_functor<impl::DimScatterScalarFunctorImpl<impl::DimScatterType::kUpdate>>(
"DimScatterUpdateScalar");
m.add_functor<impl::DimScatterScalarFunctorImpl<impl::DimScatterType::kAdd>>(
"DimScatterAddScalar");
m.add_functor<impl::DimScatterScalarFunctorImpl<impl::DimScatterType::kMultiply>>(
"DimScatterMulScalar");
m.add_functor<impl::DimScatterScalarFunctor>("DimScatterScalar");
m.add_functor<impl::DimScatterAddLikeFunctor>("DimScatterAddLike");

m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
m.add_functor<impl::CastLikeFunctor>("CastLike");
m.add_functor<impl::ElementwiseMinimumGradFunctor>("ElementwiseMinGrad");
m.add_functor<impl::ElementwiseMaximumGradFunctor>("ElementwiseMaxGrad");
m.add_functor<impl::BroadcastPowXGradFunctor>("BroadcastPowXGrad");
m.add_functor<impl::BroadcastPowYGradFunctor>("BroadcastPowYGrad");
m.add_functor<impl::DivGradFunctor>("DivGrad");
m.add_functor<impl::IdentityFunctor>("Identity");
m.add_functor<impl::AmpWhiteIdentityFunctor>("AmpWhiteIdentity");
m.add_functor<impl::AmpBlackIdentityFunctor>("AmpBlackIdentity");
m.add_functor<impl::ReduceSumLikeFunctor>("ReduceSumLike");
m.add_functor<impl::BroadcastReduceSumLikeFunctor>("BroadcastReduceSumLike");
m.add_functor<impl::SplitFunctor>("Split");
m.add_functor<impl::UnbindFunctor>("Unbind");
m.add_functor<impl::ChunkFunctor>("Chunk");
m.add_functor<impl::SplitLikeFunctor>("SplitLike");
m.add_functor<impl::SplitWithSizeFunctor>("SplitWithSize");
m.add_functor<impl::BatchGatherFunctor>("BatchGather");
m.add_functor<impl::UnsortedBatchSegmentSumFunctor>("UnsortedBatchSegmentSum");
m.add_functor<impl::MaskedFillFunctor<false>>("MaskedFill");
m.add_functor<impl::MaskedFillFunctor<true>>("MaskedFillInplace");
m.add_functor<impl::MeshgridFunctor>("Meshgrid");
m.add_functor<impl::IndexSelectFunctor>("IndexSelect");
m.add_functor<impl::ToFunctor, impl::To2Functor, impl::To3Functor, impl::To4Functor,
impl::ToDeviceFunctor>("To");
m.add_functor<impl::TopKFunctor>("TopK");
m.add_functor<impl::InTopKFunctor>("InTopK");
m.add_functor<impl::TensorToTensorBufferFunctor>("TensorToTensorBuffer");
m.add_functor<impl::TensorBufferToTensorFunctor>("TensorBufferToTensor");
m.add_functor<impl::GenTensorBufferFunctor>("GenTensorBuffer");
m.add_functor<impl::RepeatFunctor>("Repeat");
m.add_functor<impl::RepeatInterLeaveIndexFunctor>("RepeatInterLeaveIndex");
m.add_functor<impl::RepeatInterLeaveIntFunctor>("RepeatInterLeaveInt");
m.add_functor<impl::RepeatInterLeaveTensorFunctor>("RepeatInterLeaveTensor");
m.add_functor<impl::TileFunctor>("Tile");
m.add_functor<impl::TransposeAllDimPropertyFunctor>("TransposeAllDimProperty");
m.add_functor<impl::TransposeAllDimFunctionFunctor>("TransposeAllDimFunction");
m.add_functor<impl::ReshapeLikeFunctor>("ReshapeLike");
m.add_functor<impl::PinMemoryFunctor>("PinMemory");
m.add_functor<impl::BroadcastShapesFunctor>("BroadcastShapes");
m.add_functor<impl::BroadcastTensorsFunctor>("BroadcastTensors");
m.add_functor<impl::ExpandFunctor>("BroadcastTo"); // BroadcastTo is an alias of Expand
m.add_functor<impl::BinCountFunctor>("BinCount");
m.add_functor<impl::IndexAddFunctor>("IndexAdd");
m.add_functor<impl::IndexAddInplaceFunctor>("IndexAddInplace");
};

} // namespace functional
} // namespace one
} // namespace oneflow
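Aside on the output-size rule in BinCountFunctor above: the result length is the largest input value plus one, raised to minlength when that is given. A small self-contained sketch of just that rule (BinCountOutputSize is an illustrative helper, not part of the file):

#include <algorithm>
#include <cstdint>
#include <vector>

// Size rule used by the bincount functor: max element + 1, but never less
// than minlength (inputs are already checked to be non-negative).
int64_t BinCountOutputSize(const std::vector<int64_t>& values, int64_t minlength = 0) {
  int64_t max = 0;
  for (int64_t v : values) { max = std::max(max, v); }
  max += 1;
  return std::max(minlength, max);
}

// Example: values {1, 1, 3} with minlength 0 give size 4; the counts would be {0, 2, 0, 1}.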
469 changes: 234 additions & 235 deletions oneflow/core/functional/impl/common.cpp
@@ -17,273 +17,272 @@ limitations under the License.
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/common/wrap_dim_utils.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/job/rank_group.h"

namespace oneflow {
namespace one {
namespace functional {
namespace {

Maybe<Shape> InferUnifiedShapeForBroadcasting(const Shape& input_shape, const Shape& other_shape) {
  // same shapes need no broadcasting
  if (input_shape == other_shape) { return input_shape; }

  const auto unify_shapes_with_same_num_axes = [](const Shape& input_shape,
                                                  const Shape& other_shape) -> Maybe<Shape> {
    // num_axes.first == num_axes.second
    Shape target;
    for (size_t i = 0; i < input_shape.NumAxes() /* both input_shape and other_shape are ok */;
         ++i) {
      const auto num_in_curr_dim = std::make_pair(input_shape.At(i), other_shape.At(i));

      // A = (2, ), B = (2, ), A[0] == B[0], so C = (2, )
      if (num_in_curr_dim.first == num_in_curr_dim.second) {
        target.push_back(num_in_curr_dim.first);
        continue;
      }

      // A = (2, ), B = (3, ), A[0] != B[0] and A[0] != 1 and B[0] != 1, so raise RuntimeError
      if (num_in_curr_dim.first != 1 && num_in_curr_dim.second != 1) {
        return Error::RuntimeError()
               << fmt::format("input and other can't be broadcasted to a single shape. [input's "
                              "shape: {}, other's shape: {}].",
                              input_shape.ToString(), other_shape.ToString());
      }

      // A = (2, ), B = (1, ), A[0] != B[0] but B[0] == 1, so C = (2, )
      target.push_back(
          num_in_curr_dim.first == 1
              ? num_in_curr_dim.second
              : num_in_curr_dim.first);  // num_in_curr_dim.first and num_in_curr_dim.second can't
                                         // be 1 at the same time
    }
    return target;
  };

  const int64_t input_num_axes = input_shape.NumAxes();
  const int64_t other_num_axes = other_shape.NumAxes();

  if (input_num_axes == other_num_axes) {
    return unify_shapes_with_same_num_axes(input_shape, other_shape);
  }

  const int64_t unified_num_axes = std::max(input_num_axes, other_num_axes);

  // shape = (3, 4) and unified_num_axes = 3 ==> shape will be (1, 3, 4)
  const auto expand_shape_if_necessary = [unified_num_axes](const Shape& shape_to_expand) {
    const int64_t shape_to_expand_num_axes = shape_to_expand.NumAxes();
    if (shape_to_expand_num_axes < unified_num_axes) {
      auto new_shape = Shape::Ones(unified_num_axes);
      std::copy(shape_to_expand.begin(), shape_to_expand.end(),
                new_shape.begin() + (unified_num_axes - shape_to_expand_num_axes));
      return new_shape;
    }
    return shape_to_expand;
  };

  return unify_shapes_with_same_num_axes(expand_shape_if_necessary(input_shape),
                                         expand_shape_if_necessary(other_shape));
}

}  // namespace

bool IsStaticZerosTensor(const std::shared_ptr<Tensor>& x) {
  return nullptr != std::dynamic_pointer_cast<StaticZerosTensor>(x);
}

bool IsInplaceValid(const std::shared_ptr<Tensor>& x) {
  return !autograd::GradMode::is_enabled() || !(x->is_leaf() && x->requires_grad());
}

bool IsScalarTensor(const std::shared_ptr<Tensor>& x) {
  return x->shape()->NumAxes() == 0 && x->shape()->elem_cnt() == 1;
}

Maybe<std::vector<int32_t>> CheckAxis(const std::vector<int32_t>& axis, const int32_t& ndim) {
  const int32_t naxis = axis.size();
  int32_t reduce_ndim = naxis;
  if (naxis == 0 || ndim == 0) { reduce_ndim = ndim; };
  std::vector<int32_t> reduce_axis(reduce_ndim);
  if (naxis == 0) {
    std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
  } else {
    JUST(dim_list_to_bitset(axis, ndim));  // checking axis[dim]'s validation
    for (int32_t i = 0; i < naxis; i++) {
      if (i < reduce_ndim) { reduce_axis[i] = JUST(maybe_wrap_dim(axis[i], ndim)); };
    }
  }
  return reduce_axis;
}

Maybe<void> CheckInplaceValid(const std::shared_ptr<Tensor>& x) {
  CHECK_OR_RETURN(IsInplaceValid(x))
      << Error::RuntimeError()
      << "a leaf Tensor that requires grad is being used in an in-place operation";
  return Maybe<void>::Ok();
}

Maybe<void> CheckInplaceCastValid(const std::shared_ptr<Tensor>& x,
                                  const std::shared_ptr<Tensor>& x_cast) {
  CHECK_OR_RETURN(*x->dtype() == *x_cast->dtype())
      << Error::RuntimeError() << "result type " << x_cast->dtype()->name()
      << " can't be cast to the desired output type " << x->dtype()->name();
  return Maybe<void>::Ok();
}

Maybe<void> CheckInplaceShapeCanExpandTo(const Shape& shape, const Shape& expand_shape) {
  if (shape == expand_shape) { return Maybe<void>::Ok(); }

  CHECK_OR_RETURN(expand_shape.NumAxes() >= shape.NumAxes())
      << Error::RuntimeError() << "Can not expand origin shape " << shape.ToString() << " to "
      << expand_shape.ToString() << " in an inplace operation";

  int shift = expand_shape.NumAxes() - shape.NumAxes();
  for (int i = expand_shape.NumAxes() - 1; i >= 0; --i) {
    int index = i - shift;
    if (index >= 0) {
      int dim_a = expand_shape.At(i);
      int dim_b = shape.At(index);
      // NOTE(lixiang): When a dimension of tensor a and tensor b are not equal in size, dim_a
      // needs to be greater than 0, and dim_b should be equal to 1.
      CHECK_OR_RETURN(!(dim_a != dim_b && (dim_a <= 0 || dim_b != 1)))
          << Error::RuntimeError() << "Tensor with shape " << expand_shape.ToString()
          << " doesn't match the broadcast shape in an inplace operation";
    } else {
      // For 0-size tensor, expand_shape.At(i) can equal to 0.
      CHECK_OR_RETURN(expand_shape.At(i) >= 0);  // NOLINT(maybe-need-error-msg)
    }
  }

  return Maybe<void>::Ok();
}

Optional<Stride> ComputeStride(const Shape& shape, const Stride& stride,
                               const Shape& target_shape) {
  /*************************************************
   * Description: in some case, view operate is not allowed, so need to check it's validation,
   * the check refers to torch(aten/src/ATen/native/TensorShape.cpp)
   *************************************************/
  if (stride.size() == 0) {
    // for scalar input tensor
    return Stride(target_shape.NumAxes(), 1);
  }
  int64_t elem_count = shape.elem_cnt();
  int64_t ndim = shape.NumAxes();
  int64_t tgt_ndim = target_shape.NumAxes();
  DimVector shape_vec = shape.dim_vec();
  DimVector tgt_shape_vec = target_shape.dim_vec();
  if (elem_count == 0) { return NullOpt; }

  int64_t view_d = tgt_ndim - 1;
  int64_t chunk_base_stride = stride.back();
  Stride target_stride(tgt_ndim);
  // stride for each subspace in the chunk
  // numel in current chunk
  int64_t tensor_numel = 1;
  int64_t view_numel = 1;
  for (int64_t tensor_d = ndim - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= shape_vec[tensor_d];
    // if end of tensor size chunk, check view
    if ((tensor_d == 0)
        || (shape_vec[tensor_d - 1] != 1
            && stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
      while (view_d >= 0 && (view_numel < tensor_numel || tgt_shape_vec[view_d] == 1)) {
        target_stride[view_d] = view_numel * chunk_base_stride;
        view_numel *= tgt_shape_vec[view_d];
        view_d--;
      }
      if (view_numel != tensor_numel) { return NullOpt; }
      if (tensor_d > 0) {
        chunk_base_stride = stride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  if (view_d != -1) { return NullOpt; }
  return target_stride;
}

Maybe<Shape> InferShapeUnspecifiedDim(const int64_t& elem_count, const Shape& shape) {
  int need_infer_axis = -1;
  int64_t target_elem_count = 1;
  for (int i = 0; i < shape.NumAxes(); ++i) {
    if (shape.At(i) < -1) {
      return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i);
    } else if (shape.At(i) == -1) {
      CHECK_OR_RETURN_ERROR(need_infer_axis == -1)
          << Error::RuntimeError() << "only one dimension can be inferred";
      need_infer_axis = i;
    } else {
      target_elem_count *= shape.At(i);
    }
  }
  Shape infered_shape = shape;
  if (need_infer_axis == -1) {
    if (elem_count > 0) {
      // For 0-size tensor, we don't need to check the element size.
      CHECK_OR_RETURN_ERROR(target_elem_count == elem_count)
          << Error::RuntimeError() << "shape '" << shape.ToString()
          << "' is invalid for input of size " << elem_count;
    }
  } else {
    infered_shape.Set(need_infer_axis, elem_count / target_elem_count);
    CHECK_OR_RETURN_ERROR(target_elem_count * infered_shape.At(need_infer_axis) == elem_count)
        << Error::RuntimeError() << "shape '" << shape.ToString()
        << "' is invalid for input of size " << elem_count;
  }
  return infered_shape;
}

Maybe<Shape> InferUnifiedShapeForBroadcasting(const std::vector<Shape>& shapes) {
  if (shapes.empty()) { return Error::RuntimeError() << "shapes should not be empty."; }
  if (shapes.size() == 1) { return JUST(VectorAt(shapes, 0)); }

  auto result =
      *JUST(InferUnifiedShapeForBroadcasting(JUST(VectorAt(shapes, 0)), JUST(VectorAt(shapes, 1))));

  // (1, 2) vs (3, 2) => (3, 2)
  if (shapes.size() == 2) { return result; }

  /*
    (1, 3) vs (3, 1) vs (3, 1, 1)
    1. (1, 3) vs (3, 1) => (3, 3)
    2. (3, 3) vs (3, 1, 1) => (3, 3, 3)
    3. (3, 3, 3) is the final result
  */
  for (auto iter = shapes.begin() + 2; iter != shapes.end(); ++iter) {
    result = *JUST(InferUnifiedShapeForBroadcasting(result, *iter));
  }
  return result;
}

/*
  if input shapes are [(1, 3), (3, 1), (3, 1, 1)]
  will return ((3, 3, 3), [true, true, true])
  means the shape to broadcast to is (3, 3, 3) and all three shapes need broadcasting
*/
Maybe<std::tuple<Shape, std::deque<bool>>> InferUnifiedShapeForBroadcastingWithInfo(
    const std::vector<Shape>& shapes) {
  const auto unified_shape = *JUST(InferUnifiedShapeForBroadcasting(shapes));
  std::deque<bool> need_to_broadcast;
  for (const auto& x : shapes) { need_to_broadcast.emplace_back(x != unified_shape); }
  return std::make_tuple(unified_shape, need_to_broadcast);
}

Maybe<void> BroadcastSeedToAllRanks(uint64_t* seed, int64_t root) {
  CHECK_NOTNULL_OR_RETURN(seed) << "seed is not allowed to be nullptr";
  const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group));
  const auto& meta_transport_token =
      JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta));
  JUST(ccl::CpuBroadcast(seed, seed, sizeof(*seed), root, parallel_desc, meta_transport_token));
  return Maybe<void>::Ok();
}

}  // namespace functional
}  // namespace one
}  // namespace oneflow
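Aside on the broadcasting rule implemented by InferUnifiedShapeForBroadcasting above: shapes are right-aligned, missing leading axes count as 1, and per axis the sizes must match or one of them must be 1. A simplified standalone restatement on plain vectors (BroadcastShapes here is an illustrative helper, not the OneFlow implementation):

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BroadcastShapes(std::vector<int64_t> a, std::vector<int64_t> b) {
  // left-pad the shorter shape with 1s so both have the same rank
  if (a.size() < b.size()) { a.insert(a.begin(), b.size() - a.size(), 1); }
  if (b.size() < a.size()) { b.insert(b.begin(), a.size() - b.size(), 1); }
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == b[i]) {
      out[i] = a[i];
    } else if (a[i] == 1 || b[i] == 1) {
      out[i] = std::max(a[i], b[i]);  // a size-1 axis stretches to match the other
    } else {
      throw std::runtime_error("input and other can't be broadcasted to a single shape");
    }
  }
  return out;
}

// Example matching the comments above: (1, 3) vs (3, 1) gives (3, 3); folding in (3, 1, 1)
// then gives (3, 3, 3).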
2 changes: 1 addition & 1 deletion oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -11016,4 +11016,4 @@ def OneFlow_MlirJitOp : OneFlow_JITLikeOp<"mlir_jit"> {}

def OneFlow_KernelLaunchOp : OneFlow_JITLikeOp<"kernel_launch"> {}

#endif // GET_ONEFLOW_MLIR_JIT_OP_DEFINITIONS
#endif // GET_ONEFLOW_MLIR_JIT_OP_DEFINITIONS