Infer consistent tensor meta #5118

Merged 178 commits on Jul 1, 2021
178 commits
f16f663
Device::compute_dep_object_
lixinqi May 12, 2021
998cd7e
Merge branch 'master' into device_compute_dep_object
lixinqi May 12, 2021
8b4d553
sequantialize instructions in the same stream.
lixinqi May 12, 2021
b32e834
refactor AttrMap
lixinqi May 13, 2021
f0ccb1e
merge master
lixinqi May 13, 2021
172199c
refactor Tensor
lixinqi May 18, 2021
d8e30e5
merge master
lixinqi May 18, 2021
cba6684
Export ConsistentTensor::is_cuda
lixinqi May 18, 2021
f85f002
Merge branch 'master' into refactor_tensor
lixinqi May 18, 2021
980e037
remove ConsistentTensor::blob_object
lixinqi May 18, 2021
83fd994
Merge branch 'master' into refactor_tensor
lixinqi May 18, 2021
876b7c2
Merge branch 'refactor_tensor' of github.com:Oneflow-Inc/oneflow into…
lixinqi May 18, 2021
e87517e
Merge branch 'refactor_tensor' into refactor_consistent_tensor
lixinqi May 18, 2021
7a5c737
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 18, 2021
18e19b5
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 18, 2021
a74b34a
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 18, 2021
da51ce6
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 18, 2021
7d54f45
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 18, 2021
a4cb4ae
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 18, 2021
871b987
refactor TensorImpl
lixinqi May 19, 2021
f2b035b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into refactor…
lixinqi May 19, 2021
f0f96a3
minor fix
lixinqi May 19, 2021
41970b1
Merge branch 'refactor_tensor' of github.com:Oneflow-Inc/oneflow into…
lixinqi May 19, 2021
690ecf4
merge refactor_tensor
lixinqi May 19, 2021
a40eee3
Merge branch 'refactor_consistent_tensor' into eager_consistent_tensor
lixinqi May 19, 2021
63d081b
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 19, 2021
7eeecf4
fix compiler' complains
lixinqi May 19, 2021
7f9752d
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 19, 2021
cc6437b
Merge branch 'master' into refactor_tensor
oneflow-ci-bot May 19, 2021
0cd5ad8
Implements EagerConsistentTensorImpl::New
lixinqi May 19, 2021
75464d8
minor fix
lixinqi May 19, 2021
de8d1c1
Merge branch 'refactor_tensor' of github.com:Oneflow-Inc/oneflow into…
lixinqi May 19, 2021
6f7f10b
Merge branch 'refactor_tensor' into refactor_consistent_tensor
lixinqi May 19, 2021
adb7231
Merge branch 'refactor_consistent_tensor' into eager_consistent_tensor
lixinqi May 19, 2021
8ddebf4
merge master
lixinqi May 20, 2021
e20b67f
merge master
lixinqi May 20, 2021
d540ea3
fix compiler complains
lixinqi May 20, 2021
a65a9c2
remove unused code
lixinqi May 20, 2021
b06018b
Merge branch 'master' into refactor_consistent_tensor
lixinqi May 20, 2021
b3587f1
merge master
lixinqi May 20, 2021
948dff1
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan May 21, 2021
7d38063
skip test_creating_consistent_tensor
clackhan May 21, 2021
089784f
merge master
lixinqi May 21, 2021
8bb594b
Merge branch 'master' into hashable_attr_map
lixinqi May 21, 2021
119a604
Merge branch 'master' into refactor_consistent_tensor
lixinqi May 21, 2021
88e21cb
Merge branch 'refactor_consistent_tensor' into eager_consistent_tensor
lixinqi May 21, 2021
4a6484f
backup code
lixinqi May 24, 2021
4b36553
Symbol::shared_from_symbol
lixinqi May 24, 2021
753521c
Merge branch 'master' into hashable_attr_map
lixinqi May 24, 2021
fadd52d
remove redundant header file includes
lixinqi May 24, 2021
8552bcb
Merge branch 'hashable_attr_map' of github.com:Oneflow-Inc/oneflow in…
lixinqi May 24, 2021
fa81614
merge master
lixinqi May 24, 2021
282411e
Merge branch 'hashable_attr_map' into eager_consistent_tensor
lixinqi May 24, 2021
03d4d92
Merge branch 'refactor_symbol' into eager_consistent_tensor
lixinqi May 24, 2021
5c83386
fix bug in Symbol::shared_from_symbol
lixinqi May 24, 2021
adea093
Merge branch 'refactor_symbol' into eager_consistent_tensor
lixinqi May 24, 2021
d4365e5
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 24, 2021
122dd5c
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 24, 2021
d07eed4
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 24, 2021
c33f97f
symbolize ParallelDesc and ParallelDistribution
lixinqi May 25, 2021
0e8e395
Merge branch 'master' into eager_consistent_tensor
lixinqi May 25, 2021
fc71a90
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 25, 2021
27f0618
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 25, 2021
e817865
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 25, 2021
88d2844
Merge branch 'master' into refactor_symbol
oneflow-ci-bot May 25, 2021
1975028
symbolize Scope::GetParallelDesc()
lixinqi May 25, 2021
e1b833f
IsScalarType
lixinqi May 25, 2021
dc6a554
Merge branch 'refactor_symbol' into refactor_scope_parallel_desc
lixinqi May 25, 2021
9fddbc8
fix compiler complains
lixinqi May 25, 2021
28f2f20
Merge branch 'master' into eager_consistent_tensor
lixinqi May 25, 2021
e4f1f51
Merge branch 'refactor_scope_parallel_desc' into infer_consistent_ten…
lixinqi May 25, 2021
bdff487
InputConsistentTensorMeta
lixinqi May 25, 2021
9e475c6
refactor Scope with PlacementScope
lixinqi May 25, 2021
bbb044a
fix bug in exporting Scope to python
lixinqi May 25, 2021
d0fabac
backup code
lixinqi May 28, 2021
a206ffd
refactor DType
lixinqi May 28, 2021
0baf312
fix compiler complains
lixinqi May 28, 2021
b1a13a8
Merge branch 'master' into refactor_dtype
clackhan May 28, 2021
149eba2
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
6612022
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
5b02e1c
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
7a60e21
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
7fa40ba
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
1d2fbcc
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
e00e992
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
f6beda6
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
6f474b6
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
f5aada7
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 28, 2021
35e4cfa
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 30, 2021
11ede6d
backup code
lixinqi May 31, 2021
ffba0b2
DType is only allowed to be used in python code
lixinqi May 31, 2021
caca35b
Merge branch 'refactor_dtype' of github.com:Oneflow-Inc/oneflow into …
lixinqi May 31, 2021
fe551f7
merge master
lixinqi May 31, 2021
cf37d20
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
e332d50
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
b0f363b
merge master
lixinqi May 31, 2021
50c8d41
backup code
lixinqi May 31, 2021
be5a6d3
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
08cb01a
dtype api bugfix
lixinqi May 31, 2021
e94d1c5
Merge branch 'refactor_dtype' of github.com:Oneflow-Inc/oneflow into …
lixinqi May 31, 2021
01345ed
fix error on exiting
daquexian May 31, 2021
294c766
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
848c832
Merge branch 'master' into fix_atexit_error
daquexian May 31, 2021
71c53da
Merge branch 'master' into refactor_dtype
lixinqi May 31, 2021
62ececc
Merge branch 'fix_atexit_error' into refactor_dtype
lixinqi May 31, 2021
4a6c9e2
Merge branch 'refactor_dtype' of github.com:Oneflow-Inc/oneflow into …
lixinqi May 31, 2021
27e3fbb
lazily get rank
daquexian May 31, 2021
b66c308
Merge branch 'fix_atexit_error' into refactor_dtype
daquexian May 31, 2021
6470ed0
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
809c7f1
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
9ee1f40
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
282c79b
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
f097809
Merge branch 'master' into refactor_dtype
oneflow-ci-bot May 31, 2021
aec513e
Merge branch 'master' into refactor_dtype
oneflow-ci-bot Jun 1, 2021
b35cffc
Export const DType* into python
lixinqi Jun 1, 2021
8819b35
merge branch refactor_dtype and fix compiler complain
lixinqi Jun 1, 2021
bef5488
merge master
lixinqi Jun 1, 2021
7aced04
minor fix
lixinqi Jun 1, 2021
350ef56
Merge branch 'master' into op_expr_infer_tensor_meta
lixinqi Jun 2, 2021
557589e
fix bug
clackhan Jun 2, 2021
9e9f133
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jun 2, 2021
8ec6793
refine
clackhan Jun 2, 2021
812601b
merge master
lixinqi Jun 3, 2021
02e05bb
Merge branch 'master' into op_expr_infer_tensor_meta
lixinqi Jun 3, 2021
ddf06da
Merge branch 'op_expr_infer_tensor_meta' into infer_consistent_tensor…
lixinqi Jun 3, 2021
3c77886
refactor signature of OpExpr::InferLogicalShapeAndDtype
lixinqi Jun 3, 2021
c8a0002
fix bug
clackhan Jun 3, 2021
faad0f4
Merge branch 'op_expr_infer_tensor_meta' of https://github.com/Oneflo…
clackhan Jun 3, 2021
13dfc33
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jun 3, 2021
144c87d
Tensor::mut_eager_mirrored_tensor_impl
lixinqi Jun 3, 2021
2e78e55
Merge branch 'op_expr_infer_tensor_meta' into infer_consistent_tensor…
lixinqi Jun 3, 2021
b523c27
backup_code
lixinqi Jun 3, 2021
7ac2444
fix bug
clackhan Jun 3, 2021
adc6bba
Merge branch 'master' into op_expr_infer_tensor_meta
clackhan Jun 3, 2021
fd1c54f
refactor SbpXXX to cfg::SbpXXX
lixinqi Jun 3, 2021
d31c000
Merge branch 'master' into op_expr_infer_tensor_meta
clackhan Jun 4, 2021
079c9c2
Merge branch 'master' into refactor_sbp_to_cfg_sbp
lixinqi Jun 4, 2021
4323fff
Merge branch 'refactor_sbp_to_cfg_sbp' into infer_consistent_tensor_meta
lixinqi Jun 4, 2021
0939eb3
merge refactor_sbp_to_cfg_sbp
lixinqi Jun 4, 2021
36b7785
fix bug
clackhan Jun 4, 2021
bc5e083
Merge branch 'op_expr_infer_tensor_meta' of https://github.com/Oneflo…
clackhan Jun 4, 2021
7a21e7a
Merge branch 'master' into op_expr_infer_tensor_meta
clackhan Jun 4, 2021
64963c3
Merge branch 'op_expr_infer_tensor_meta' into infer_consistent_tensor…
lixinqi Jun 4, 2021
0e93aee
Infer ConsistentTensorMeta
lixinqi Jun 5, 2021
ae97c62
Implement EagerConsistentInterpret::ApplyImpl
lixinqi Jun 6, 2021
347bfa9
1) move XXXTensorMeta into the new file tensor_meta.h; 2) add new Cla…
lixinqi Jun 7, 2021
333402d
merge master
lixinqi Jun 7, 2021
97a6b25
add class ConsistentTensorInferResult
lixinqi Jun 7, 2021
887ceb4
remove unused OpArgMutConsistentTensorMeta::parallel_distribution_
lixinqi Jun 7, 2021
9c314fc
Merge branch 'master' into infer_consistent_tensor_meta
clackhan Jun 7, 2021
b56027f
Merge branch 'master' into infer_consistent_tensor_meta
lixinqi Jun 7, 2021
8f3cd5c
fix stack-overflow bug in Tensor::mut_eager_mirrored_tensor_impl
lixinqi Jun 7, 2021
c60f7ea
Merge branch 'infer_consistent_tensor_meta' of github.com:Oneflow-Inc…
lixinqi Jun 7, 2021
b82e5b0
ignore empty parallel distribution constaint
lixinqi Jun 7, 2021
03e47f9
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jun 7, 2021
eb78de3
Merge branch 'infer_consistent_tensor_meta' of https://github.com/One…
clackhan Jun 7, 2021
30522a3
fix bug
clackhan Jun 7, 2021
48666f1
Merge branch 'master' into infer_consistent_tensor_meta
clackhan Jun 9, 2021
c4dab31
Merge branch 'master' into infer_consistent_tensor_meta
clackhan Jun 9, 2021
a25a2a5
merge master
lixinqi Jun 22, 2021
a2bfb9f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jun 25, 2021
436b77f
Merge branch 'master' into infer_consistent_tensor_meta
jackalcooper Jun 25, 2021
47c4a79
add explicit of cfg
clackhan Jun 25, 2021
002c20b
Merge branch 'master' into infer_consistent_tensor_meta
clackhan Jun 25, 2021
8b454a9
fix xla compile bug
clackhan Jun 28, 2021
ed00650
Merge branch 'infer_consistent_tensor_meta' of https://github.com/One…
clackhan Jun 28, 2021
582c367
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jun 28, 2021
2b355b1
auto format by CI
oneflow-ci-bot Jun 28, 2021
19fb61f
merge master
lixinqi Jun 29, 2021
face33e
merge origin
lixinqi Jun 29, 2021
fb46186
fix according comment
clackhan Jun 29, 2021
d51e7f1
Merge branch 'master' into infer_consistent_tensor_meta
clackhan Jun 29, 2021
2969c2b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jul 1, 2021
57a549b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jul 1, 2021
5d6f38f
Merge branch 'master' into infer_consistent_tensor_meta
oneflow-ci-bot Jul 1, 2021
87917ea
fix bug
clackhan Jul 1, 2021
5cd6a3b
Merge branch 'infer_consistent_tensor_meta' of https://github.com/One…
clackhan Jul 1, 2021
45787fc
Merge branch 'master' into infer_consistent_tensor_meta
oneflow-ci-bot Jul 1, 2021
1 change: 1 addition & 0 deletions cmake/oneflow.cmake
@@ -395,6 +395,7 @@ list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/new_kernel_u
list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_context.h")
list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_util.cuh")
list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/sbp_signature_builder.h")
list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/common/symbol.h")
list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/parallel_desc.h")
list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/autograd/autograd_meta.h")
copy_files("${OF_CORE_HDRS}" "${PROJECT_SOURCE_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy)
3 changes: 3 additions & 0 deletions oneflow/core/framework/arg_tuple.cpp
@@ -48,6 +48,9 @@ ArgTuple::ArgTuple(const std::vector<std::string>& indexed_bns) : indexed_bns_(i
for (const auto& bn : indexed_bns) { indexed_arg_name_and_index_.push_back(GetPair(bn)); }
InitArgName2BnIndex2TensorTupleIndex(indexed_arg_name_and_index_,
&arg_name2bn_index2tensor_tuple_index_);
for (int i = 0; i < indexed_bns.size(); ++i) {
bn_in_op2tensor_tuple_index_[indexed_bns.at(i)] = i;
}
}

int32_t ArgTuple::TensorTupleIndex4ArgNameAndIndex(const std::string& name, int32_t index) const {
4 changes: 4 additions & 0 deletions oneflow/core/framework/arg_tuple.h
@@ -37,6 +37,9 @@ class ArgTuple final {
arg_name2bn_index2tensor_tuple_index() const {
return arg_name2bn_index2tensor_tuple_index_;
}
const std::unordered_map<std::string, int32_t>& bn_in_op2tensor_tuple_index() const {
return bn_in_op2tensor_tuple_index_;
}

// return -1 if not found
int32_t TensorTupleIndex4ArgNameAndIndex(const std::string& name, int32_t index) const;
@@ -45,6 +48,7 @@ class ArgTuple final {
std::vector<std::string> indexed_bns_;
std::vector<std::pair<std::string, int32_t>> indexed_arg_name_and_index_;
std::unordered_map<std::string, std::vector<int32_t>> arg_name2bn_index2tensor_tuple_index_;
std::unordered_map<std::string, int32_t> bn_in_op2tensor_tuple_index_;
};

} // namespace oneflow
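For context, the new bn_in_op2tensor_tuple_index_ map records each indexed blob name's position in the tensor tuple, giving a direct O(1) lookup from a full blob name such as "in_0" without splitting it into an (arg name, index) pair; the infer cache added below consults it via bn_in_op2tensor_tuple_index().at(ibn). A minimal standalone sketch of that mapping follows (illustrative blob names, not OneFlow source):

// Standalone sketch: the lookup that bn_in_op2tensor_tuple_index_ provides,
// from a blob name in op to its position in the tensor tuple.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Indexed blob names as an ArgTuple might receive them (names are made up).
  std::vector<std::string> indexed_bns{"in_0", "in_1", "scalar_0"};

  // Same construction idea as in the ArgTuple constructor above: each blob
  // name maps to its index in the indexed_bns vector.
  std::unordered_map<std::string, int32_t> bn_in_op2tensor_tuple_index;
  for (int32_t i = 0; i < static_cast<int32_t>(indexed_bns.size()); ++i) {
    bn_in_op2tensor_tuple_index[indexed_bns.at(i)] = i;
  }

  // O(1) lookup by the full blob name, the way the infer cache below uses it.
  std::cout << bn_in_op2tensor_tuple_index.at("in_1") << std::endl;  // prints 1
  return 0;
}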
209 changes: 209 additions & 0 deletions oneflow/core/framework/consistent_tensor_infer_cache.cpp
@@ -0,0 +1,209 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/job/placement_scope.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/op_expr.h"

namespace oneflow {
namespace one {

size_t InputConsistentTensorMeta::hash_value() const {
return std::hash<Symbol<ConsistentTensorMeta>>()(tensor_meta())
^ std::hash<Symbol<cfg::ParallelDistribution>>()(
consumer_parallel_distribution_constraint());
}

bool InputConsistentTensorMeta::operator==(const InputConsistentTensorMeta& other) const {
return this->tensor_meta() == other.tensor_meta()
&& this->consumer_parallel_distribution_constraint()
== other.consumer_parallel_distribution_constraint();
}

void InputConsistentTensorMeta::assign(
Symbol<ConsistentTensorMeta> tensor_meta,
Symbol<cfg::ParallelDistribution> consumer_parallel_distribution_constraint) {
tensor_meta_ = tensor_meta;
consumer_parallel_distribution_constraint_ = consumer_parallel_distribution_constraint;
}

Maybe<void> ConsistentTensorMetaInferArgs::Init(const TensorTuple& input_tensors,
Symbol<PlacementScope> placement_scope,
const AttrMap& attrs) {
input_consistent_tensor_metas_.resize(input_tensors.size());
placement_scope_ = placement_scope;
attrs_ = attrs;
JUST(InitInputConsistentTensorMetas(input_tensors));
return Maybe<void>::Ok();
}

size_t ConsistentTensorMetaInferArgs::hash_value() const {
size_t hash_value = std::hash<Symbol<PlacementScope>>()(placement_scope_);
hash_value ^= std::hash<AttrMap>()(attrs_);
const auto& tensor_meta_hash_functor = std::hash<InputConsistentTensorMeta>();
for (const auto& tensor_meta : input_consistent_tensor_metas_) {
HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta));
}
return hash_value;
}

bool ConsistentTensorMetaInferArgs::operator==(const ConsistentTensorMetaInferArgs& other) const {
return this->input_consistent_tensor_metas_ == other.input_consistent_tensor_metas_
&& this->placement_scope_ == other.placement_scope_ && this->attrs_ == other.attrs_;
}

Maybe<void> ConsistentTensorMetaInferArgs::MakeParallelDistributionConstraints(
const UserOpExpr& user_op_expr,
cfg::ParallelDistributionSignature* parallel_distribution_signature) const {
const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();
auto* map = parallel_distribution_signature->mutable_bn_in_op2parallel_distribution();
for (int i = 0; i < input_arg_tuple.size(); ++i) {
const auto& constraint =
input_consistent_tensor_metas_.at(i).consumer_parallel_distribution_constraint();
if (constraint) { (*map)[input_arg_tuple.indexed_bns().at(i)] = *constraint; }
}
return Maybe<void>::Ok();
}

Maybe<void> ConsistentTensorMetaInferArgs::MakeInputBlobDescs(
const UserOpExpr& user_op_expr, std::vector<BlobDesc>* blob_descs) const {
CHECK_OR_RETURN(blob_descs->empty());
const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();
blob_descs->reserve(input_arg_tuple.size());
for (int i = 0; i < input_arg_tuple.size(); ++i) {
const auto& tensor_meta = *input_consistent_tensor_metas_.at(i).tensor_meta();
const auto& shape = std::const_pointer_cast<Shape>(tensor_meta.shape_ptr());
blob_descs->emplace_back(shape, tensor_meta.data_type());
}
return Maybe<void>::Ok();
}

Maybe<void> ConsistentTensorMetaInferArgs::MakeParallelDistributionInferHints(
const UserOpExpr& user_op_expr, const std::vector<BlobDesc>& blob_descs,
std::vector<ParallelDistributionInferHint>* hints) const {
CHECK_OR_RETURN(hints->empty());
const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();
hints->reserve(input_arg_tuple.size());
for (int i = 0; i < input_arg_tuple.size(); ++i) {
const auto& tensor_meta = *input_consistent_tensor_metas_.at(i).tensor_meta();
const auto* parallel_desc = &*tensor_meta.parallel_desc();
const auto* blob_desc = &blob_descs.at(i);
const auto* parallel_distribution = &*tensor_meta.parallel_distribution();
hints->emplace_back(parallel_desc, blob_desc, parallel_distribution);
}
return Maybe<void>::Ok();
}

Maybe<void> ConsistentTensorMetaInferArgs::InitInputConsistentTensorMetas(
const TensorTuple& input_tensors) {
for (int i = 0; i < input_tensors.size(); ++i) {
const auto& tensor = *input_tensors.at(i);
const auto& tensor_meta = JUST(tensor.consistent_tensor_meta());
const auto& constraints = JUST(tensor.consumer_parallel_distribution_constraint());
input_consistent_tensor_metas_.at(i).assign(tensor_meta, constraints);
}
return Maybe<void>::Ok();
}

namespace {

Maybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,
const std::string& device_tag) {
OperatorConf op_conf;
JUST(user_op_expr.BuildOpConf(&op_conf, attrs));
DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag));
return ConstructOp(op_conf, device_type);
}

} // namespace

/*static*/ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::Infer(
const UserOpExpr& user_op_expr, const ConsistentTensorMetaInferArgs& infer_args) {
Symbol<ParallelDesc> parallel_desc;
{
// Get parallel description.
const auto& placement_scope = infer_args.placement_scope();
parallel_desc = JUST(placement_scope->GetParallelDesc(user_op_expr.op_type_name()));
}
std::vector<OpArgMutConsistentTensorMeta> output_mut_metas(user_op_expr.output_size());
{
// Infer OpArgMutConsistentTensorMeta.
const auto& input_metas = infer_args.input_consistent_tensor_metas();
JUST(user_op_expr.InferLogicalShapeAndDType(
infer_args.attrs(), parallel_desc->device_tag(),
[&](int32_t i) { return &*input_metas.at(i).tensor_meta(); },
[&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));
}
const auto& op = JUST(MakeOp(user_op_expr, infer_args.attrs(), parallel_desc->device_tag()));
op->FillOpParallelDesc(parallel_desc.shared_from_symbol());
{
// Infer parallel distribution.
cfg::ParallelDistributionSignature parallel_distribution_constraints;
JUST(infer_args.MakeParallelDistributionConstraints(user_op_expr,
&parallel_distribution_constraints));
std::vector<BlobDesc> blob_descs;
JUST(infer_args.MakeInputBlobDescs(user_op_expr, &blob_descs));
std::vector<ParallelDistributionInferHint> pd_infer_hints;
JUST(infer_args.MakeParallelDistributionInferHints(user_op_expr, blob_descs, &pd_infer_hints));
const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();
const auto& ParallelDistributionInferHint4Ibn =
[&](const std::string& ibn) -> Maybe<const ParallelDistributionInferHint*> {
int32_t input_index = input_arg_tuple.bn_in_op2tensor_tuple_index().at(ibn);
CHECK_GE_OR_RETURN(input_index, 0);
CHECK_LT_OR_RETURN(input_index, pd_infer_hints.size());
return &pd_infer_hints.at(input_index);
};
// The inferred results can be retrieved by op->ParallelDistribution4BnInOp(obn).
JUST(op->InferParallelDistributionSignatureIf(parallel_distribution_constraints, *parallel_desc,
ParallelDistributionInferHint4Ibn));
}
auto* result =
new ConsistentTensorInferResult(user_op_expr.input_size(), user_op_expr.output_size());
auto* input_pd = result->mut_input_parallel_distributions();
for (int32_t i = 0; i < user_op_expr.input_size(); ++i) {
const auto& ibn = user_op_expr.input_arg_tuple()->indexed_bns().at(i);
input_pd->at(i) = SymbolOf(*JUST(op->ParallelDistribution4BnInOp(ibn)));
}
auto* output_metas = result->mut_output_tensor_metas();
for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {
const auto& output_mut_meta = output_mut_metas.at(i);
const auto& shape = output_mut_meta.tensor_meta().shape_ptr();
DataType data_type = output_mut_meta.tensor_meta().data_type();
const auto& obn = user_op_expr.output_arg_tuple()->indexed_bns().at(i);
const auto& parallel_distribution = SymbolOf(*JUST(op->ParallelDistribution4BnInOp(obn)));
ConsistentTensorMeta tensor_meta(shape, data_type, parallel_distribution, parallel_desc);
output_metas->at(i) = SymbolOf(tensor_meta);
}
return std::shared_ptr<const ConsistentTensorInferResult>(result);
}

Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
const ConsistentTensorMetaInferArgs& infer_args) {
auto iter = cache_.find(infer_args);
if (iter == cache_.end()) {
const auto& user_op_expr = user_op_expr_.lock();
CHECK_OR_RETURN(static_cast<bool>(user_op_expr));
const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));
iter = cache_.emplace(infer_args, output_tensor_metas).first;
}
return iter->second;
}

} // namespace one
} // namespace oneflow
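GetOrInfer above memoizes Infer results keyed by ConsistentTensorMetaInferArgs, which is why that type defines operator== and hash_value(). A standalone sketch of the same find-or-insert pattern follows; the types and values are stand-ins, not the OneFlow API:

#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>

// Stand-in for ConsistentTensorMetaInferArgs: equality plus a hash make it a cache key.
struct InferArgs {
  int64_t op_id;
  int64_t input_signature;
  bool operator==(const InferArgs& other) const {
    return op_id == other.op_id && input_signature == other.input_signature;
  }
};

struct InferArgsHash {
  size_t operator()(const InferArgs& args) const {
    // XOR-combine member hashes, in the spirit of the hash_value() methods above.
    return std::hash<int64_t>()(args.op_id) ^ std::hash<int64_t>()(args.input_signature);
  }
};

struct InferResult { int64_t value; };

class InferCache {
 public:
  std::shared_ptr<const InferResult> GetOrInfer(const InferArgs& args) {
    auto iter = cache_.find(args);
    if (iter == cache_.end()) {
      // Cache miss: run the (expensive) inference once and remember the result.
      auto result = std::make_shared<const InferResult>(InferResult{args.input_signature * 2});
      iter = cache_.emplace(args, result).first;
    }
    return iter->second;  // cache hit on later calls with equal args
  }

 private:
  std::unordered_map<InferArgs, std::shared_ptr<const InferResult>, InferArgsHash> cache_;
};

int main() {
  InferCache cache;
  std::cout << cache.GetOrInfer({1, 21})->value << std::endl;  // infers and caches: 42
  std::cout << cache.GetOrInfer({1, 21})->value << std::endl;  // served from the cache: 42
  return 0;
}

In the real code each ConsistentTensorInferCache holds a weak reference to its UserOpExpr, so the cache key only needs to distinguish input tensor metas, placement scope, and attributes.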