Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RegisterCustomOpsLibrary via the Python API #4764

Merged
merged 37 commits into from
Aug 28, 2020
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
306a114
Initial commit
hariharans29 Aug 10, 2020
8bcc970
More changes
hariharans29 Aug 12, 2020
b033785
Resolve conflicts
hariharans29 Aug 12, 2020
ebb6987
Nit
hariharans29 Aug 12, 2020
a72aa6d
Nit
hariharans29 Aug 12, 2020
5a735ed
Fix warning
hariharans29 Aug 12, 2020
f65e504
More changes
hariharans29 Aug 12, 2020
cf70867
Debug linux failure
hariharans29 Aug 12, 2020
08ebd49
debug
hariharans29 Aug 12, 2020
a1c53c9
debug 2
hariharans29 Aug 12, 2020
62064b8
Debug
hariharans29 Aug 13, 2020
6514806
More changes
hariharans29 Aug 13, 2020
aecbc45
More changes
hariharans29 Aug 14, 2020
14b0b36
Changes
hariharans29 Aug 14, 2020
03569a0
Fix link
hariharans29 Aug 14, 2020
bb4bf92
Fix training build
hariharans29 Aug 19, 2020
29bb93f
Merge master and resolve conflicts
hariharans29 Aug 19, 2020
0f08762
Fix build
hariharans29 Aug 19, 2020
fa64d83
Fix build
hariharans29 Aug 19, 2020
d46b900
Merge remote-tracking branch 'origin/master' into CustomOpPythonApi
hariharans29 Aug 24, 2020
537337c
PR feedback
hariharans29 Aug 25, 2020
dcc1963
Fix training build
hariharans29 Aug 25, 2020
3c554ad
Fix training build
hariharans29 Aug 25, 2020
adfa4e3
Debug training build
hariharans29 Aug 25, 2020
7f562be
a
hariharans29 Aug 25, 2020
bc1e4d3
Fix training build
hariharans29 Aug 25, 2020
ad8ff15
Revert some changes
hariharans29 Aug 25, 2020
61450d4
More changes
hariharans29 Aug 27, 2020
51d8f43
Fix mac build
hariharans29 Aug 27, 2020
e0977b4
Fix build
hariharans29 Aug 27, 2020
bb8e0b1
More changes
hariharans29 Aug 27, 2020
ac4f920
Log
hariharans29 Aug 27, 2020
feb8501
Nit
hariharans29 Aug 27, 2020
82ef63d
Nits
hariharans29 Aug 27, 2020
1d0d7ed
refinement
hariharans29 Aug 27, 2020
57b4490
Nit
hariharans29 Aug 28, 2020
a841e6d
PR feedback
hariharans29 Aug 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Debug training build
  • Loading branch information
hariharans29 committed Aug 25, 2020
commit adfa4e36b17d2e23ea76b2db056c62936e9b2eb4
38 changes: 19 additions & 19 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ TrainingConfigurationResult ConfigureSessionForTraining(
auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
if (data_group_size != parameters.data_parallel_size) {
LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
<< data_group_size;
<< data_group_size;
parameters.data_parallel_size = data_group_size;
}

Expand Down Expand Up @@ -145,23 +145,23 @@ TrainingConfigurationResult ConfigureSessionForTraining(

#if defined(USE_NCCL)
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();

parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
parameters.local_size = MPIContext::GetInstance().GetLocalSize();
if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
if (parameters.world_rank != 0)
LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
}
if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
if (parameters.world_size != 1)
LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
parameters.world_size = MPIContext::GetInstance().GetWorldSize();
}
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();

parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
parameters.local_size = MPIContext::GetInstance().GetLocalSize();
if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
if (parameters.world_rank != 0)
LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
}
if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
if (parameters.world_size != 1)
LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
parameters.world_size = MPIContext::GetInstance().GetWorldSize();
}
}
#endif

Expand Down Expand Up @@ -204,7 +204,7 @@ void addObjectMethodsForTraining(py::module& m) {
return py::none();
});

py::class_<onnxruntime::training::TrainingSession, PyInferenceSession> training_session(m, "TrainingSession");
py::class_<onnxruntime::training::TrainingSession> training_session(m, "TrainingSession");
training_session.def(py::init([](const PySessionOptions& so) {
Environment& env = GetEnv();
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
Expand Down