Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mig op arg para attr #4102

Merged
merged 37 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bf10f8f
GetPhysicalOpArgBlobAttr
lixinqi Jan 6, 2021
1776952
cfg hash
lixinqi Jan 6, 2021
212620c
fix bug
clackhan Jan 6, 2021
34d33ba
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into cfg_hash
lixinqi Jan 6, 2021
71bd044
Merge branch 'cfg_hash' into mig_py_cfg_sbp
lixinqi Jan 6, 2021
8a2b2f2
Merge branch 'cfg_hash' of github.com:Oneflow-Inc/oneflow into mig_py…
lixinqi Jan 6, 2021
1cb49a3
cfg::SbpParallel typed OpArgParallelAttribute.sbp_parallel
lixinqi Jan 6, 2021
1a08550
cfg::OptMirroredParallel typed OpArgParallelAttribute.opt_mirrored_p…
clackhan Jan 7, 2021
2b5746a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jan 7, 2021
d488fd1
rename PyLazyConsistentBlob and PyLazyMirroredBlob
clackhan Jan 7, 2021
3156b6b
fix EagerConsistentBlob bug of property parallel_size
clackhan Jan 7, 2021
85d1927
use static_cast
clackhan Jan 7, 2021
e727aaf
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jan 7, 2021
94d8d4b
Merge branch 'master' into mig_py_cfg_sbp
clackhan Jan 7, 2021
60d4aab
del redefine hash
clackhan Jan 7, 2021
71f2e58
Merge branch 'mig_py_cfg_sbp' of https://github.com/Oneflow-Inc/onefl…
clackhan Jan 7, 2021
2575e9a
mig_op_arg_para_attr
clackhan Jan 7, 2021
c172f8c
mig DumpToInterfaceBlobConf
clackhan Jan 7, 2021
2b43299
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jan 7, 2021
ed831f2
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jan 8, 2021
29b10bc
fix bug
clackhan Jan 8, 2021
6fb88a5
fix inter_face_blob_conf.proto
clackhan Jan 8, 2021
2c6d13b
Merge branch 'master' into mig_op_arg_para_attr
clackhan Jan 8, 2021
6c3ba07
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jan 8, 2021
08db882
replace None with oneflow_api.INVALID_BATCH_AXIS
clackhan Jan 8, 2021
5a96611
migrate python OpNodeSignatureSymbol to c++ version
lixinqi Jan 9, 2021
813982b
merge master
lixinqi Jan 9, 2021
596b100
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jan 11, 2021
5be3d75
Merge branch 'mig_op_arg_para_attr' of https://github.com/Oneflow-Inc…
clackhan Jan 11, 2021
bae23d7
fix CONFLICT
clackhan Jan 11, 2021
f8ec1a2
mig DumpToOpNodeSignature
clackhan Jan 11, 2021
305dda6
Merge branch 'master' into mig_op_arg_para_attr
clackhan Jan 11, 2021
03335d7
mig op_arg_util.py completely
clackhan Jan 11, 2021
158c725
Merge branch 'mig_op_arg_para_attr' of https://github.com/Oneflow-Inc…
clackhan Jan 11, 2021
f46be93
Merge branch 'master' into mig_op_arg_para_attr
clackhan Jan 11, 2021
8d75db8
Merge branch 'master' into mig_op_arg_para_attr
clackhan Jan 12, 2021
65b4dc4
Merge branch 'master' into mig_op_arg_para_attr
oneflow-ci-bot Jan 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cmake/cfg.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/job/job_conf.proto
oneflow/core/job/placement.proto
oneflow/core/operator/op_conf.proto
oneflow/core/operator/inter_face_blob_conf.proto
oneflow/core/operator/interface_blob_conf.proto
oneflow/core/common/shape.proto
oneflow/core/record/image.proto
oneflow/core/record/record.proto
Expand All @@ -52,6 +52,7 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/job/scope.proto
oneflow/core/job/mirrored_parallel.proto
oneflow/core/operator/op_attribute.proto
oneflow/core/operator/op_node_signature.proto
oneflow/core/register/batch_axis_signature.proto
oneflow/core/operator/arg_modifier_signature.proto
oneflow/core/job/blob_lifetime_signature.proto
Expand Down Expand Up @@ -108,13 +109,16 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/job/scope.proto
oneflow/core/job/mirrored_parallel.proto
oneflow/core/operator/op_attribute.proto
oneflow/core/operator/op_node_signature.proto
oneflow/core/register/batch_axis_signature.proto
oneflow/core/job/parallel_signature.proto
oneflow/core/job/initializer_conf.proto
oneflow/core/job/regularizer_conf.proto
oneflow/core/job/learning_rate_schedule_conf.proto
oneflow/core/common/data_type.proto
oneflow/core/common/device_type.proto
oneflow/core/register/logical_blob_id.proto
oneflow/core/operator/inter_face_blob_conf.proto
oneflow/core/operator/interface_blob_conf.proto
oneflow/core/common/shape.proto
oneflow/core/register/blob_desc.proto
oneflow/core/register/pod.proto
Expand Down
63 changes: 27 additions & 36 deletions oneflow/api/python/deprecated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,59 +19,50 @@ limitations under the License.
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/job/mirrored_parallel.cfg.h"
#include "oneflow/core/job/mirrored_parallel.pb.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/operator/op_attribute.pb.h"
#include "oneflow/core/operator/op_node_signature.pb.h"
#include "oneflow/core/operator/op_node_signature.cfg.h"
#include "oneflow/core/job/sbp_parallel.cfg.h"
#include "oneflow/core/job/mirrored_parallel.cfg.h"
#include "oneflow/core/register/blob_desc.cfg.h"
#include "oneflow/core/register/batch_axis_signature.cfg.h"
#include "oneflow/core/job/parallel_signature.cfg.h"
#include "oneflow/core/common/data_type.cfg.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/register/blob_desc.cfg.h"
#include "oneflow/core/register/blob_desc.pb.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/maybe.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<cfg::SbpParallel> MakeSbpParallel(const std::string& serialized_str) {
SbpParallel sbp_parallel;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &sbp_parallel))
<< "sbp_parallel parse failed";
return std::make_shared<cfg::SbpParallel>(sbp_parallel);
}

Maybe<cfg::OptMirroredParallel> MakeOptMirroredParallel(const std::string& serialized_str) {
OptMirroredParallel opt_mirrored_parallel;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &opt_mirrored_parallel))
<< "opt_mirrored_parallel parse failed";
return std::make_shared<cfg::OptMirroredParallel>(opt_mirrored_parallel);
}

Maybe<cfg::OptInt64> MakeOptInt64(const std::string& serialized_str) {
OptInt64 opt_int64;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &opt_int64)) << "opt_int64 parse failed";
return std::make_shared<cfg::OptInt64>(opt_int64);
}

Maybe<cfg::BlobDescProto> MakeBlobDescProto(const std::string& serialized_str) {
BlobDescProto blob_desc;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &blob_desc)) << "blob_desc parse failed";
return std::make_shared<cfg::BlobDescProto>(blob_desc);
Maybe<cfg::OpNodeSignature> MakeOpNodeSignatureFromSerializedOpAttribute(
const std::string& op_attribute_str) {
OpAttribute op_attribute;
CHECK_OR_RETURN(TxtString2PbMessage(op_attribute_str, &op_attribute))
<< "op_attribute parse failed";
auto op_node_signature = std::make_shared<cfg::OpNodeSignature>();
op_node_signature->mutable_sbp_signature()->InitFromProto(op_attribute.sbp_signature());
op_node_signature->mutable_mirrored_signature()->InitFromProto(op_attribute.mirrored_signature());
op_node_signature->mutable_logical_blob_desc_signature()->InitFromProto(
op_attribute.logical_blob_desc_signature());
op_node_signature->mutable_batch_axis_signature()->InitFromProto(
op_attribute.batch_axis_signature());
op_node_signature->mutable_parallel_signature()->InitFromProto(op_attribute.parallel_signature());
return op_node_signature;
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("deprecated", m) {
m.def("MakeSbpParrallelByString",
[](const std::string& str) { return MakeSbpParallel(str).GetPtrOrThrow(); });

m.def("MakeOptMirroredParrallelByString",
[](const std::string& str) { return MakeOptMirroredParallel(str).GetPtrOrThrow(); });

m.def("MakeOptInt64ByString",
[](const std::string& str) { return MakeOptInt64(str).GetPtrOrThrow(); });

m.def("MakeBlobDescProtoByString",
[](const std::string& str) { return MakeBlobDescProto(str).GetPtrOrThrow(); });
m.def("MakeOpNodeSignatureFromSerializedOpAttribute", [](const std::string& str) {
return MakeOpNodeSignatureFromSerializedOpAttribute(str).GetPtrOrThrow();
});
}

} // namespace oneflow
111 changes: 104 additions & 7 deletions oneflow/api/python/framework/op_arg_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,94 @@ limitations under the License.
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/op_arg_util.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/job/mirrored_parallel.pb.h"
#include "oneflow/core/common/data_type.cfg.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/register/blob_desc.cfg.h"
#include "oneflow/core/register/blob_desc.pb.h"
#include "oneflow/core/operator/op_attribute.pb.h"
#include "oneflow/core/common/protobuf.h"

namespace py = pybind11;

namespace oneflow {

namespace compatible_py {

namespace {

Maybe<cfg::SbpParallel> MakeSbpParallel(const std::string& serialized_str) {
SbpParallel sbp_parallel;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &sbp_parallel))
<< "sbp_parallel parse failed";
return std::make_shared<cfg::SbpParallel>(sbp_parallel);
}

Maybe<cfg::OptMirroredParallel> MakeOptMirroredParallel(const std::string& serialized_str) {
OptMirroredParallel opt_mirrored_parallel;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &opt_mirrored_parallel))
<< "opt_mirrored_parallel parse failed";
return std::make_shared<cfg::OptMirroredParallel>(opt_mirrored_parallel);
}

Maybe<cfg::OptInt64> MakeOptInt64(const std::string& serialized_str) {
if (serialized_str.empty()) { return std::make_shared<cfg::OptInt64>(); }
OptInt64 opt_int64;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &opt_int64)) << "opt_int64 parse failed";
return std::make_shared<cfg::OptInt64>(opt_int64);
}

Maybe<cfg::BlobDescProto> MakeBlobDescProto(const std::string& serialized_str) {
if (serialized_str.empty()) { return std::make_shared<cfg::BlobDescProto>(); }
BlobDescProto blob_desc;
CHECK_OR_RETURN(TxtString2PbMessage(serialized_str, &blob_desc)) << "blob_desc parse failed";
return std::make_shared<cfg::BlobDescProto>(blob_desc);
}

Maybe<OpArgBlobAttribute> CreatOpArgBlobAttribute(const std::string& batch_axis_str,
const std::string& blob_desc_str,
const std::string& logical_blob_name) {
const std::shared_ptr<cfg::OptInt64>& batch_axis = JUST(MakeOptInt64(batch_axis_str));
const std::shared_ptr<cfg::BlobDescProto>& blob_desc = JUST(MakeBlobDescProto(blob_desc_str));
return std::make_shared<OpArgBlobAttribute>(batch_axis, blob_desc, logical_blob_name);
}

Maybe<OpArgParallelAttribute> CreatOpArgParallelAttribute(
std::shared_ptr<ParallelDesc> parallel_desc, const std::string& sbp_parallel_str,
const std::string& opt_mirrored_parallel_str) {
std::shared_ptr<cfg::SbpParallel> sbp_parallel = JUST(MakeSbpParallel(sbp_parallel_str));
std::shared_ptr<cfg::OptMirroredParallel> opt_mirrored_parallel =
JUST(MakeOptMirroredParallel(opt_mirrored_parallel_str));
return std::make_shared<OpArgParallelAttribute>(parallel_desc, sbp_parallel,
opt_mirrored_parallel);
}

Maybe<OpArgBlobAttribute> ApiGetOpArgBlobAttribute(const std::string& op_attribute_str,
const std::string& bn_in_op) {
OpAttribute op_attribute;
CHECK_OR_RETURN(TxtString2PbMessage(op_attribute_str, &op_attribute))
<< "op_attribute parse failed";
return GetOpArgBlobAttribute(op_attribute, bn_in_op);
}

Maybe<OpArgParallelAttribute> ApiGetOpArgParallelAttribute(
const std::shared_ptr<ParallelDesc>& parallel_desc_symbol, const std::string& op_attribute_str,
const std::string& bn_in_op) {
OpAttribute op_attribute;
CHECK_OR_RETURN(TxtString2PbMessage(op_attribute_str, &op_attribute))
<< "op_attribute parse failed";
return GetOpArgParallelAttribute(parallel_desc_symbol, op_attribute, bn_in_op);
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<OpArgBlobAttribute, std::shared_ptr<OpArgBlobAttribute>>(m, "OpArgBlobAttribute")
.def(py::init([](const std::shared_ptr<cfg::OptInt64>& batch_axis,
const std::shared_ptr<cfg::BlobDescProto>& blob_desc,
.def(py::init([](const std::string& batch_axis_str, const std::string& blob_desc_str,
const std::string& logical_blob_name) {
return std::make_shared<OpArgBlobAttribute>(batch_axis, blob_desc, logical_blob_name);
return CreatOpArgBlobAttribute(batch_axis_str, blob_desc_str, logical_blob_name)
.GetPtrOrThrow();
}))
.def_property_readonly("batch_axis", &OpArgBlobAttribute::batch_axis)
.def_property_readonly("blob_desc", &OpArgBlobAttribute::blob_desc)
Expand All @@ -49,16 +124,19 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
[](const std::shared_ptr<OpArgBlobAttribute>& x) {
return static_cast<int>(x->get_dtype());
})
.def("GetPhysicalOpArgBlobAttr", &OpArgBlobAttribute::GetPhysicalOpArgBlobAttr)
.def("DumpToInterfaceBlobConf", &OpArgBlobAttribute::DumpToInterfaceBlobConf)
.def("DumpToOpNodeSignature", &OpArgBlobAttribute::DumpToOpNodeSignature)
.def(py::self == py::self);

py::class_<OpArgParallelAttribute, std::shared_ptr<OpArgParallelAttribute>>(
m, "OpArgParallelAttribute")
.def(py::init([](std::shared_ptr<ParallelDesc> parallel_desc,
std::shared_ptr<cfg::SbpParallel> sbp_parallel,
std::shared_ptr<cfg::OptMirroredParallel> opt_mirrored_parallel) {
return std::make_shared<OpArgParallelAttribute>(parallel_desc, sbp_parallel,
opt_mirrored_parallel);
const std::string& sbp_parallel_str,
const std::string& opt_mirrored_parallel_str) {
return CreatOpArgParallelAttribute(parallel_desc, sbp_parallel_str,
opt_mirrored_parallel_str)
.GetPtrOrThrow();
}))
.def_property_readonly("parallel_desc_symbol", &OpArgParallelAttribute::parallel_desc_symbol)
.def_property_readonly("sbp_parallel", &OpArgParallelAttribute::sbp_parallel)
Expand All @@ -68,10 +146,29 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def("_Hash", &OpArgParallelAttribute::_Hash)
.def("Assign", &OpArgParallelAttribute::Assign)
.def("DumpToInterfaceBlobConf", &OpArgParallelAttribute::DumpToInterfaceBlobConf)
.def("DumpToOpNodeSignature", &OpArgParallelAttribute::DumpToOpNodeSignature)
.def("__str__", &OpArgParallelAttribute::ToString)
.def("__repr__", &OpArgParallelAttribute::ToString)
.def(py::self == py::self)
.def(py::hash(py::self));
m.def("GetOpArgBlobAttribute",
[](const std::string& op_attribute_str, const std::string& bn_in_op) {
return ApiGetOpArgBlobAttribute(op_attribute_str, bn_in_op).GetPtrOrThrow();
});
m.def("GetOpArgParallelAttribute",
[](const std::shared_ptr<ParallelDesc>& parallel_desc_symbol,
const std::string& op_attribute_str, const std::string& bn_in_op) {
return ApiGetOpArgParallelAttribute(parallel_desc_symbol, op_attribute_str, bn_in_op)
.GetPtrOrThrow();
});
m.def("MakeMirroredOpArgParallelAttribute",
[](const std::shared_ptr<ParallelDesc>& parallel_desc_symbol) {
return MakeMirroredOpArgParallelAttribute(parallel_desc_symbol).GetPtrOrThrow();
});
m.def("MakeBroadcastOpArgParallelAttribute",
[](const std::shared_ptr<ParallelDesc>& parallel_desc_symbol) {
return MakeBroadcastOpArgParallelAttribute(parallel_desc_symbol).GetPtrOrThrow();
});
}

} // namespace compatible_py
Expand Down
1 change: 1 addition & 0 deletions oneflow/api/python/op/op_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/api/python/framework/framework.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/operator/op_attribute.pb.h"
#include "oneflow/core/operator/op_node_signature.pb.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/framework/user_op_conf.h"
Expand Down
45 changes: 45 additions & 0 deletions oneflow/api/python/symbol/op_node_signature_symbol.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/operator/op_node_signature_desc.h"
#include "oneflow/core/operator/op_node_signature.pb.h"
#include "oneflow/core/operator/op_node_signature.cfg.h"

namespace py = pybind11;

namespace oneflow {

Maybe<OpNodeSignatureDesc> CreateScopeSymbol(
int64_t symbol_id, const std::shared_ptr<cfg::OpNodeSignature>& symbol_conf) {
OpNodeSignature symbol_pb;
symbol_conf->ToProto(&symbol_pb);
return std::make_shared<OpNodeSignatureDesc>(symbol_id, symbol_pb);
}

ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<OpNodeSignatureDesc, std::shared_ptr<OpNodeSignatureDesc>>(m, "OpNodeSignatureSymbol")
.def(
py::init([](int64_t symbol_id, const std::shared_ptr<cfg::OpNodeSignature>& symbol_conf) {
return CreateScopeSymbol(symbol_id, symbol_conf).GetPtrOrThrow();
}))
.def_property_readonly(
"symbol_id", [](const OpNodeSignatureDesc& x) { return x.symbol_id().GetOrThrow(); })
.def("data", &OpNodeSignatureDesc::cfg_op_node_signature);
}

} // namespace oneflow
25 changes: 18 additions & 7 deletions oneflow/api/python/symbol/symbol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ limitations under the License.
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/scope.cfg.h"
#include "oneflow/core/job/scope.pb.h"
#include "oneflow/core/operator/op_node_signature_desc.h"
#include "oneflow/core/operator/op_node_signature.cfg.h"
#include "oneflow/core/operator/op_node_signature.pb.h"

namespace py = pybind11;

Expand All @@ -30,14 +33,19 @@ namespace oneflow {
namespace {

template<typename SymbolConfT>
bool ApiHasSymbol(const SymbolConfT& symbol_conf) {
const auto& id_cache = *Global<symbol::IdCache<SymbolConfT>>::Get();
Maybe<bool> HasSymbol(const SymbolConfT& symbol_conf) {
const auto& id_cache = *JUST(GlobalMaybe<symbol::IdCache<SymbolConfT>>());
return id_cache.Has(symbol_conf);
}

template<typename SymbolConfT>
bool ApiHasSymbol(const SymbolConfT& symbol_conf) {
return HasSymbol(symbol_conf).GetOrThrow();
}

template<typename SymbolConfT, typename SymbolT>
Maybe<SymbolT> GetSymbol(const SymbolConfT& symbol_conf) {
const auto& id_cache = *Global<symbol::IdCache<SymbolConfT>>::Get();
const auto& id_cache = *JUST(GlobalMaybe<symbol::IdCache<SymbolConfT>>());
const auto& symbol_storage = *Global<symbol::Storage<SymbolT>>::Get();
int64_t symbol_id = JUST(id_cache.Get(symbol_conf));
const auto& ptr = JUST(symbol_storage.MaybeGetPtr(symbol_id));
Expand All @@ -51,7 +59,7 @@ Maybe<void> AddSymbol(int64_t symbol_id, const SymbolConfT& symbol_conf) {
SymbolPbT symbol_pb;
symbol_conf.ToProto(&symbol_pb);
JUST(Global<symbol::Storage<SymbolT>>::Get()->Add(symbol_id, symbol_pb));
auto* id_cache = Global<symbol::IdCache<SymbolConfT>>::Get();
auto* id_cache = JUST(GlobalMaybe<symbol::IdCache<SymbolConfT>>());
CHECK_OR_RETURN(!id_cache->Has(symbol_conf));
JUST(id_cache->FindOrCreate(symbol_conf, [&symbol_id]() -> Maybe<int64_t> { return symbol_id; }));
return Maybe<void>::Ok();
Expand Down Expand Up @@ -85,21 +93,24 @@ std::shared_ptr<SymbolT> ApiGetSymbolById(int64_t symbol_id) {
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("HasPlacementSymbol", &ApiHasSymbol<cfg::ParallelConf>);
m.def("AddPlacementSymbol", &ApiAddSymbol<cfg::ParallelConf, ParallelConf, ParallelDesc>);

m.def("GetPlacementSymbol", &ApiGetSymbol<cfg::ParallelConf, ParallelDesc>);
m.def("GetPlacementSymbol", &ApiGetSymbolById<cfg::ParallelConf, ParallelDesc>);

m.def("HasJobConfSymbol", &ApiHasSymbol<cfg::JobConfigProto>);
m.def("AddJobConfSymbol", &ApiAddSymbol<cfg::JobConfigProto, JobConfigProto, JobDesc>);

m.def("GetJobConfSymbol", &ApiGetSymbol<cfg::JobConfigProto, JobDesc>);
m.def("GetJobConfSymbol", &ApiGetSymbolById<cfg::JobConfigProto, JobDesc>);

m.def("HasScopeSymbol", &ApiHasSymbol<cfg::ScopeProto>);
m.def("AddScopeSymbol", &ApiAddSymbol<cfg::ScopeProto, ScopeProto, Scope>);

m.def("GetScopeSymbol", &ApiGetSymbol<cfg::ScopeProto, Scope>);
m.def("GetScopeSymbol", &ApiGetSymbolById<cfg::ScopeProto, Scope>);

m.def("HasOpNodeSignatureSymbol", &ApiHasSymbol<cfg::OpNodeSignature>);
m.def("AddOpNodeSignatureSymbol",
&ApiAddSymbol<cfg::OpNodeSignature, OpNodeSignature, OpNodeSignatureDesc>);
m.def("GetOpNodeSignatureSymbol", &ApiGetSymbol<cfg::OpNodeSignature, OpNodeSignatureDesc>);
m.def("GetOpNodeSignatureSymbol", &ApiGetSymbolById<cfg::OpNodeSignature, OpNodeSignatureDesc>);
}

} // namespace oneflow
Loading