diff --git a/src/command/marian.cpp b/src/command/marian.cpp index e90ec92e2..6aa9ae441 100644 --- a/src/command/marian.cpp +++ b/src/command/marian.cpp @@ -11,7 +11,7 @@ #include "training/graph_group_async_drop.h" #endif -bool configureMPI(int, char**); +bool configureMPI(int, char**, int); int main(int argc, char** argv) { using namespace marian; @@ -20,7 +20,11 @@ int main(int argc, char** argv) { auto devices = options->getDevices(); if(options->get("multi-node")) { - ABORT_IF(!configureMPI(argc, argv), "MPI not found."); + if (options->get("sync-sgd")) { + ABORT_IF(!configureMPI(argc, argv, MPI_THREAD_SERIALIZED), "MPI not found."); + } else { + ABORT_IF(!configureMPI(argc, argv, MPI_THREAD_MULTIPLE), "MPI not found."); + } LOG(warn, "[experimental] Running multi-node training"); if (!options->get("sync-sgd")) { @@ -46,7 +50,7 @@ int main(int argc, char** argv) { return 0; } -bool configureMPI(int argc, char** argv) { +bool configureMPI(int argc, char** argv, int required_mode) { bool enable = false; #if MPI_FOUND int provided_thread_mode = 0; @@ -55,7 +59,7 @@ bool configureMPI(int argc, char** argv) { MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN); ABORT_IF( - provided_thread_mode < MPI_THREAD_MULTIPLE, + provided_thread_mode < required_mode, "Your version of MPI does not support multi-threaded communication."); enable = true; diff --git a/src/training/graph_group_multinode_sync.h b/src/training/graph_group_multinode_sync.h index 17aedfc5e..7b5b76ffe 100644 --- a/src/training/graph_group_multinode_sync.h +++ b/src/training/graph_group_multinode_sync.h @@ -217,7 +217,6 @@ class MultiNodeGraphGroupSync : public GraphGroup { */ virtual ~MultiNodeGraphGroupSync() { //@TODO merge with finalize method - MPI_Finalize(); delete clientThreadPool_; } @@ -309,6 +308,7 @@ class MultiNodeGraphGroupSync : public GraphGroup { virtual void finalize() { finalized_ = true; + MPI_Finalize(); } }; }