Commit

address CR comments from Mark
fmegen committed Oct 14, 2016
1 parent 9e42119 commit 3d4723c
Showing 4 changed files with 183 additions and 182 deletions.
CNTK.Cpp.props (2 changes: 1 addition & 1 deletion)
@@ -124,7 +124,7 @@

 <ItemDefinitionGroup>
   <ClCompile>
-    <PreprocessorDefinitions>HAS_OPENMPI=1</PreprocessorDefinitions>
+    <PreprocessorDefinitions>HAS_MPI=1</PreprocessorDefinitions>
   </ClCompile>
 </ItemDefinitionGroup>

Makefile (2 changes: 1 addition & 1 deletion)
@@ -76,7 +76,7 @@ SSE_FLAGS = -msse4.1 -mssse3
 SOURCEDIR:= Source
 INCLUDEPATH:= $(addprefix $(SOURCEDIR)/, Common/Include CNTKv2LibraryDll CNTKv2LibraryDll/API Math CNTK ActionsLib ComputationNetworkLib SGDLib SequenceTrainingLib CNTK/BrainScript Readers/ReaderLib)
 # COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
-COMMON_FLAGS:= -DHAS_OPENMPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11
+COMMON_FLAGS:= -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11
 CPPFLAGS:=
 CXXFLAGS:= $(SSE_FLAGS) -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new
 LIBPATH:=
Source/Common/Include/MPIWrapper.h (161 changes: 13 additions & 148 deletions)
@@ -4,7 +4,7 @@
 //
 #pragma once

-#if HAS_OPENMPI
+#if HAS_MPI
 // Please see https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Windows#ms-mpi or
 // https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-Linux#open-mpi for setup instructions
 // of an MPI implementation on your platform.
@@ -19,7 +19,7 @@
#include "mpi.h"
#endif
#else
// Note: the following macros define some of the MPI related functions and constants such that code
// Note: the following macros/typedefs define some of the MPI related functions and constants such that code
// using these functionality will compile cleanly - but will not actually perform the MPI operation.
// The clean way to go is to move any code related to mpi into the mpiwrapper class implementation and decide
// in this class if to use mpi.h or not.
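To make the note above concrete, here is a minimal sketch of what such no-op stand-ins can look like. The names mirror MPI's API for illustration only; these are not CNTK's actual definitions.

// Illustrative no-op stubs (assumed shapes; not the actual CNTK macros).
// They let MPI-using code compile when no MPI implementation is present,
// while every "operation" degenerates to a successful no-op.
typedef void* MPI_Comm;      // dummy communicator handle
typedef int   MPI_Datatype;  // dummy datatype tag

#define MPI_SUCCESS    0
#define MPI_COMM_WORLD ((MPI_Comm)0)

#define MPI_Init(argc, argv)      (MPI_SUCCESS)              // nothing to initialize
#define MPI_Finalize()            (MPI_SUCCESS)              // nothing to tear down
#define MPI_Barrier(comm)         (MPI_SUCCESS)              // no peers to wait for
#define MPI_Comm_rank(comm, rank) (*(rank) = 0, MPI_SUCCESS) // always rank 0
#define MPI_Comm_size(comm, size) (*(size) = 1, MPI_SUCCESS) // single "node"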
@@ -107,158 +107,23 @@ class MPIWrapper : public std::enable_shared_from_this<MPIWrapper>
     static MPI_Datatype GetDataType(size_t *);

     // allreduce of a vector
-    virtual void AllReduce(std::vector<size_t>&accumulator) const = 0;
-    virtual void AllReduce(std::vector<int>&accumulator) const = 0;
-    virtual void AllReduce(std::vector<double>&accumulator) const = 0;
-    virtual void AllReduce(std::vector<float>&accumulator) const = 0;
+    virtual void AllReduce(std::vector<size_t>& accumulator) const = 0;
+    virtual void AllReduce(std::vector<int>& accumulator) const = 0;
+    virtual void AllReduce(std::vector<double>& accumulator) const = 0;
+    virtual void AllReduce(std::vector<float>& accumulator) const = 0;

     // for raw pointer
-    virtual void AllReduce(size_t*pData, size_t nData) = 0;
-    virtual void AllReduce(int*pData, size_t nData) = 0;
-    virtual void AllReduce(double*pData, size_t nData) = 0;
-    virtual void AllReduce(float*pData, size_t nData) = 0;
+    virtual void AllReduce(size_t* pData, size_t nData) = 0;
+    virtual void AllReduce(int* pData, size_t nData) = 0;
+    virtual void AllReduce(double* pData, size_t nData) = 0;
+    virtual void AllReduce(float* pData, size_t nData) = 0;

-    virtual void Bcast(size_t*pData, size_t nData, size_t srcRank) = 0;
-    virtual void Bcast(double*pData, size_t nData, size_t srcRank) = 0;
-    virtual void Bcast(float*pData, size_t nData, size_t srcRank) = 0;
+    virtual void Bcast(size_t* pData, size_t nData, size_t srcRank) = 0;
+    virtual void Bcast(double* pData, size_t nData, size_t srcRank) = 0;
+    virtual void Bcast(float* pData, size_t nData, size_t srcRank) = 0;

     // wait for all ranks to reach here
     virtual int WaitAll() = 0;
 };


-#if HAS_OPENMPI
-
-class MPIWrapperMpi : public MPIWrapper
-{
-    int m_myRank;
-    int m_numMPINodes;
-    size_t m_numNodesInUse;
-
-    // MPI communicator that reflects the current subset selection
-    MPI_Comm m_currentComm;
-
-    // MPI_Init() with delay-loading the msmpi.dll (possibly causing a failure if missing; we want to catch that)
-    int MPI_Init_DL();
-
-    // Workaround for the issue of MPI hanging when CNTK processes exit with non-zero exit codes.
-    // OpenMPI has a confirmed race condition between killing child processes and handling their non-zero
-    // exit statuses, resulting in a deadlock where all processes are killed but MPI is still waiting.
-    // This happens when several perfectly synchronized processes (for example, on an MPI barrier)
-    // simultaneously exit with a non-zero exit code.
-    // As a workaround, we simply sleep 50*rank milliseconds, effectively "de-synchronizing" the processes
-    // at exit and allowing MPI to handle the terminations sequentially (see the sketch after this class).
-    static int s_myRank;
-    static void MPIWorkaroundAtExit();

-public:
-    MPIWrapperMpi();
-
-    // Note: we don't clear the sub-communication here although we should, because in case of a crash, this prevents the EXE from terminating.
-    // It's OK since this class is a singleton anyway that gets instantiated exactly once at program startup.
-    ~MPIWrapperMpi();
-
-private:
-    void Ping(const char *msg) const;
-    MPI_Comm Communicator() const;
-
-    void RequestNodes(const char *msg, size_t requestednodes = SIZE_MAX /*default: all*/);
-
-public:
-    size_t NumNodesInUse() const;
-    size_t CurrentNodeRank() const;
-    bool IsMainNode() const;
-    bool IsIdle() const;
-    bool UsingAllNodes() const;
-    size_t MainNodeRank() const;
-
-    // -----------------------------------------------------------------------
-    // data-exchange functions (wrappers around MPI functions)
-    // -----------------------------------------------------------------------
-
-    virtual int Finalize(void);
-    virtual int Wait(MPI_Request* request, MPI_Status* status);
-    virtual int Waitany(int count, MPI_Request array_of_requests[], int* index, MPI_Status* status);
-    virtual int Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]);
-    virtual int Isend(const void* buf, int count, MPI_Datatype datatype, int dest, int tag, /*MPI_Comm comm,*/ MPI_Request* request);
-    virtual int Recv(void* buf, int count, MPI_Datatype datatype, int source, int tag, /*MPI_Comm comm,*/ MPI_Status* status);
-    virtual int Irecv(void* buf, int count, MPI_Datatype datatype, int source, int tag, /*MPI_Comm comm,*/ MPI_Request* request);
-    virtual int Iallreduce(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, /*MPI_Comm comm,*/ MPI_Request* request);
-    virtual int Abort(int errorcode);
-    virtual int Error_string(int errorcode, char* string, int* resultlen);
-
-    // allreduce of a vector
-    virtual void AllReduce(std::vector<size_t>&accumulator) const;
-    virtual void AllReduce(std::vector<int>&accumulator) const;
-    virtual void AllReduce(std::vector<double>&accumulator) const;
-    virtual void AllReduce(std::vector<float>&accumulator) const;
-
-    // for raw pointer
-    virtual void AllReduce(size_t*pData, size_t nData);
-    virtual void AllReduce(int*pData, size_t nData);
-    virtual void AllReduce(double*pData, size_t nData);
-    virtual void AllReduce(float*pData, size_t nData);
-
-    virtual void Bcast(size_t*pData, size_t nData, size_t srcRank);
-    virtual void Bcast(double*pData, size_t nData, size_t srcRank);
-    virtual void Bcast(float*pData, size_t nData, size_t srcRank);
-
-    // wait for all ranks to reach here
-    int WaitAll();
-};
-
-#endif
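The exit workaround referenced by MPIWorkaroundAtExit above is simple enough to sketch. The following is an assumption of its shape, not the actual CNTK implementation (which lives in the .cpp):

#include <chrono>
#include <cstdlib>
#include <thread>

static int s_myRank = 0; // set once the MPI rank is known

// Registered via std::atexit(): stagger termination by 50 ms per rank so
// that perfectly synchronized ranks never exit with a non-zero status at
// the same instant, avoiding the OpenMPI shutdown race described above.
static void MPIWorkaroundAtExit()
{
    std::this_thread::sleep_for(std::chrono::milliseconds(50 * s_myRank));
}

// During startup, after MPI_Comm_rank() has filled in the rank:
//     s_myRank = myRank;
//     std::atexit(MPIWorkaroundAtExit);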

-class MPIWrapperEmpty : public MPIWrapper
-{
-public:
-    MPIWrapperEmpty();
-
-    // Note: we don't clear the sub-communication here although we should, because in case of a crash, this prevents the EXE from terminating.
-    // It's OK since this class is a singleton anyway that gets instantiated exactly once at program startup.
-    ~MPIWrapperEmpty();
-
-    size_t NumNodesInUse() const;
-    size_t CurrentNodeRank() const;
-    bool IsMainNode() const;
-    bool IsIdle() const;
-    bool UsingAllNodes() const;
-    size_t MainNodeRank() const;
-
-    // -----------------------------------------------------------------------
-    // data-exchange functions (wrappers around MPI functions)
-    // -----------------------------------------------------------------------
-
-    virtual int Finalize(void);
-    virtual int Wait(MPI_Request* request, MPI_Status* status);
-    virtual int Waitany(int count, MPI_Request array_of_requests[], int* index, MPI_Status* status);
-    virtual int Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]);
-    virtual int Isend(const void* buf, int count, MPI_Datatype datatype, int dest, int tag, /*MPI_Comm comm,*/ MPI_Request* request);
-    virtual int Recv(void* buf, int count, MPI_Datatype datatype, int source, int tag, /*MPI_Comm comm,*/ MPI_Status* status);
-    virtual int Irecv(void* buf, int count, MPI_Datatype datatype, int source, int tag, /*MPI_Comm comm,*/ MPI_Request* request);
-    virtual int Iallreduce(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, /*MPI_Comm comm,*/ MPI_Request* request);
-    virtual int Abort(int errorcode);
-    virtual int Error_string(int errorcode, char* string, int* resultlen);
-
-    // allreduce of a vector
-    virtual void AllReduce(std::vector<size_t>&accumulator) const;
-    virtual void AllReduce(std::vector<int>&accumulator) const;
-    virtual void AllReduce(std::vector<double>&accumulator) const;
-    virtual void AllReduce(std::vector<float>&accumulator) const;
-
-    // for raw pointer
-    virtual void AllReduce(size_t*pData, size_t nData);
-    virtual void AllReduce(int*pData, size_t nData);
-    virtual void AllReduce(double*pData, size_t nData);
-    virtual void AllReduce(float*pData, size_t nData);
-
-    virtual void Bcast(size_t*pData, size_t nData, size_t srcRank);
-    virtual void Bcast(double*pData, size_t nData, size_t srcRank);
-    virtual void Bcast(float*pData, size_t nData, size_t srcRank);
-
-    // wait for all ranks to reach here
-    int WaitAll();
-};

 }}}
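With the concrete class declarations moved out of the header, callers see only the abstract MPIWrapper interface. A hypothetical factory (name and shape invented for illustration; not CNTK's actual API) shows how the two implementations would be selected at build time:

#include <memory>

// Hypothetical factory, for illustration only: return the MPI-backed wrapper
// when an MPI implementation was available at build time, and the do-nothing
// wrapper otherwise, so call sites never need #if HAS_MPI themselves.
std::shared_ptr<MPIWrapper> CreateMpiWrapper()
{
#if HAS_MPI
    return std::make_shared<MPIWrapperMpi>();   // real MPI communication
#else
    return std::make_shared<MPIWrapperEmpty>(); // compiles, communicates nothing
#endif
}

Either way, a call such as mpi->AllReduce(buffer, count) compiles unchanged; per the header's comment, the empty implementation simply performs no MPI operation.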