Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[enhancement] use conversion operator semantics to wrap types #804

Merged
merged 3 commits into from
Jan 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 57 additions & 57 deletions src/mpi/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,19 @@ struct op_wrapper;
// NOTE(review): this span is a GitHub diff rendering with the +/- gutter
// stripped. Each specialization therefore shows TWO member lines back to
// back: the removed pre-change accessor (`static auto kind()`) and the added
// post-change implicit conversion operator. In the merged header only the
// `operator MPI_Op()` line remains — do not read this as one struct having
// both members.
//
// Maps the compile-time op_t enumerator onto the corresponding predefined
// MPI reduction-operation handle (MPI_Op), so call sites can write
// `op_wrapper<mpi_op__>()` and let the conversion operator supply MPI_SUM etc.
template <>
struct op_wrapper<op_t::sum>
{
static auto kind() {return MPI_SUM;}
operator MPI_Op() const noexcept {return MPI_SUM;}
};

template <>
struct op_wrapper<op_t::max>
{
static auto kind() {return MPI_MAX;}
operator MPI_Op() const noexcept {return MPI_MAX;}
};

template <>
struct op_wrapper<op_t::min>
{
static auto kind() {return MPI_MIN;}
operator MPI_Op() const noexcept {return MPI_MIN;}
};

template <typename T>
Expand All @@ -89,79 +89,79 @@ struct type_wrapper;
// Maps a C++ element type onto the matching predefined MPI_Datatype handle.
// NOTE(review): diff rendering — in each specialization the first member line
// (`static auto kind()`) is the removed form and the second (the implicit
// conversion to MPI_Datatype) is its replacement; only the operator survives
// in the merged file.
template <>
struct type_wrapper<float>
{
static auto kind() {return MPI_FLOAT;}
operator MPI_Datatype() const noexcept {return MPI_FLOAT;}
};

template <>
struct type_wrapper<std::complex<float>>
{
static auto kind() {return MPI_C_FLOAT_COMPLEX;}
operator MPI_Datatype() const noexcept {return MPI_C_FLOAT_COMPLEX;}
};

template <>
struct type_wrapper<double>
{
static auto kind() {return MPI_DOUBLE;}
operator MPI_Datatype() const noexcept {return MPI_DOUBLE;}
};

// NOTE(review): this hunk is NOT valid C++ as rendered. Besides swapping
// kind() for a conversion operator, the change also reordered the
// `long double` and `std::complex<double>` specializations, so the stripped
// diff interleaves a removed struct-header line directly above its added
// replacement (two consecutive `struct type_wrapper<...>` lines per block).
// Post-merge reading: the first block is type_wrapper<std::complex<double>>
// returning MPI_C_DOUBLE_COMPLEX, the second is type_wrapper<long double>
// returning MPI_LONG_DOUBLE — confirm against the merged header.
template <>
struct type_wrapper<long double>
struct type_wrapper<std::complex<double>>
{
static auto kind() {return MPI_LONG_DOUBLE;}
operator MPI_Datatype() const noexcept {return MPI_C_DOUBLE_COMPLEX;}
};

template <>
struct type_wrapper<std::complex<double>>
struct type_wrapper<long double>
{
static auto kind() {return MPI_C_DOUBLE_COMPLEX;}
operator MPI_Datatype() const noexcept {return MPI_LONG_DOUBLE;}
};

// Integral/bool specializations of type_wrapper: each associates a C++ type
// with its predefined MPI_Datatype handle.
// NOTE(review): diff rendering — per specialization, the `static auto kind()`
// line is the removed old API and the `operator MPI_Datatype()` line is its
// replacement; only the conversion operator exists after the merge.
template <>
struct type_wrapper<int>
{
static auto kind() {return MPI_INT;}
operator MPI_Datatype() const noexcept {return MPI_INT;}
};

template <>
struct type_wrapper<int16_t>
{
static auto kind() {return MPI_SHORT;}
operator MPI_Datatype() const noexcept {return MPI_SHORT;}
};

template <>
struct type_wrapper<char>
{
static auto kind() {return MPI_CHAR;}
operator MPI_Datatype() const noexcept {return MPI_CHAR;}
};

template <>
struct type_wrapper<unsigned char>
{
static auto kind() {return MPI_UNSIGNED_CHAR;}
operator MPI_Datatype() const noexcept {return MPI_UNSIGNED_CHAR;}
};

template <>
struct type_wrapper<unsigned long long>
{
static auto kind() {return MPI_UNSIGNED_LONG_LONG;}
operator MPI_Datatype() const noexcept {return MPI_UNSIGNED_LONG_LONG;}
};

template <>
struct type_wrapper<unsigned long>
{
static auto kind() {return MPI_UNSIGNED_LONG;}
operator MPI_Datatype() const noexcept {return MPI_UNSIGNED_LONG;}
};

template <>
struct type_wrapper<bool>
{
// MPI_CXX_BOOL is the MPI datatype corresponding to C++ bool.
static auto kind() {return MPI_CXX_BOOL;}
operator MPI_Datatype() const noexcept {return MPI_CXX_BOOL;}
};

template <>
struct type_wrapper<uint32_t>
{
static auto kind() {return MPI_UINT32_T;}
operator MPI_Datatype() const noexcept {return MPI_UINT32_T;}
};

struct block_data_descriptor
Expand Down Expand Up @@ -421,46 +421,46 @@ class Communicator
/// In-place reduction onto rank root__: on the root, buffer__ is both send
/// and receive buffer (MPI_IN_PLACE); on every other rank it is send-only
/// (receive buffer passed as NULL per the MPI_Reduce contract).
/// NOTE(review): diff rendering — each adjacent pair of CALL_MPI lines is the
/// removed form (type_wrapper<T>::kind() / op_wrapper<...>::kind()) followed
/// by its replacement (implicit conversion via type_wrapper<T>() /
/// op_wrapper<...>()); the merged code performs ONE MPI_Reduce per branch,
/// not two. The leading `template <typename T, op_t mpi_op__>` line is hidden
/// in the collapsed diff context above.
inline void reduce(T* buffer__, int count__, int root__) const
{
if (root__ == rank()) {
CALL_MPI(MPI_Reduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), root__, this->native()));
CALL_MPI(MPI_Reduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), root__, this->native()));
} else {
CALL_MPI(MPI_Reduce, (buffer__, NULL, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), root__, this->native()));
CALL_MPI(MPI_Reduce, (buffer__, NULL, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), root__, this->native()));
}
}

/// Non-blocking in-place reduction onto rank root__ (MPI_Ireduce); the
/// request handle is returned through req__ and must be completed by the
/// caller (e.g. MPI_Wait). Root uses MPI_IN_PLACE; non-root ranks send only.
/// NOTE(review): diff rendering — each pair of CALL_MPI lines is the old
/// ::kind() form followed by the new conversion-operator form of the SAME
/// call; merged code issues one MPI_Ireduce per branch.
template <typename T, op_t mpi_op__ = op_t::sum>
inline void reduce(T* buffer__, int count__, int root__, MPI_Request* req__) const
{
if (root__ == rank()) {
CALL_MPI(MPI_Ireduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), root__, this->native(), req__));
CALL_MPI(MPI_Ireduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), root__, this->native(), req__));
} else {
CALL_MPI(MPI_Ireduce, (buffer__, NULL, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), root__, this->native(), req__));
CALL_MPI(MPI_Ireduce, (buffer__, NULL, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), root__, this->native(), req__));
}
}

/// Out-of-place blocking reduction onto rank root__: sendbuf__ is read on
/// all ranks, recvbuf__ receives the result (significant on the root only).
/// NOTE(review): diff rendering — the two CALL_MPI lines are the removed
/// ::kind() form and its conversion-operator replacement; one MPI_Reduce
/// call in the merged code.
template <typename T, op_t mpi_op__ = op_t::sum>
void reduce(T const* sendbuf__, T* recvbuf__, int count__, int root__) const
{
CALL_MPI(MPI_Reduce, (sendbuf__, recvbuf__, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), root__, this->native()));
CALL_MPI(MPI_Reduce, (sendbuf__, recvbuf__, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), root__, this->native()));
}

/// Non-blocking out-of-place reduction onto rank root__ (MPI_Ireduce);
/// caller must complete the request returned via req__ before reading
/// recvbuf__ or reusing sendbuf__.
/// NOTE(review): diff rendering — old ::kind() call line followed by its
/// conversion-operator replacement; one MPI_Ireduce call in the merged code.
template <typename T, op_t mpi_op__ = op_t::sum>
void reduce(T const* sendbuf__, T* recvbuf__, int count__, int root__, MPI_Request* req__) const
{
CALL_MPI(MPI_Ireduce, (sendbuf__, recvbuf__, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), root__, this->native(), req__));
CALL_MPI(MPI_Ireduce, (sendbuf__, recvbuf__, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), root__, this->native(), req__));
}

/// Perform the in-place (the output buffer is used as the input buffer) all-to-all reduction.
/// The reduced result (default op_t::sum) is available in buffer__ on every
/// rank after the call (MPI_Allreduce with MPI_IN_PLACE).
/// NOTE(review): diff rendering — the two CALL_MPI lines are the removed
/// ::kind() form and the added conversion-operator form of one and the same
/// MPI_Allreduce call.
template <typename T, op_t mpi_op__ = op_t::sum>
inline void allreduce(T* buffer__, int count__) const
{
CALL_MPI(MPI_Allreduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), this->native()));
CALL_MPI(MPI_Allreduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), this->native()));
}

/// Perform the in-place (the output buffer is used as the input buffer) all-to-all reduction.
Expand All @@ -476,8 +476,8 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Iallreduce");
#endif
CALL_MPI(MPI_Iallreduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>::kind(),
op_wrapper<mpi_op__>::kind(), this->native(), req__));
CALL_MPI(MPI_Iallreduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
op_wrapper<mpi_op__>(), this->native(), req__));
}

/// Perform buffer broadcast.
Expand All @@ -487,7 +487,7 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Bcast");
#endif
CALL_MPI(MPI_Bcast, (buffer__, count__, type_wrapper<T>::kind(), root__, this->native()));
CALL_MPI(MPI_Bcast, (buffer__, count__, type_wrapper<T>(), root__, this->native()));
}

inline void bcast(std::string& str__, int root__) const
Expand Down Expand Up @@ -515,7 +515,7 @@ class Communicator
PROFILE("MPI_Allgatherv");
#endif
CALL_MPI(MPI_Allgatherv, (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, buffer__, recvcounts__, displs__,
type_wrapper<T>::kind(), this->native()));
type_wrapper<T>(), this->native()));
}

/// Out-of-place MPI_Allgatherv.
Expand All @@ -526,8 +526,8 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Allgatherv");
#endif
CALL_MPI(MPI_Allgatherv, (sendbuf__, sendcount__, type_wrapper<T>::kind(), recvbuf__, recvcounts__,
displs__, type_wrapper<T>::kind(), this->native()));
CALL_MPI(MPI_Allgatherv, (sendbuf__, sendcount__, type_wrapper<T>(), recvbuf__, recvcounts__,
displs__, type_wrapper<T>(), this->native()));
}

template <typename T>
Expand All @@ -539,7 +539,7 @@ class Communicator
v[2 * rank() + 1] = displs__;

CALL_MPI(MPI_Allgather,
(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>::kind(), this->native()));
(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>(), this->native()));

std::vector<int> counts(size());
std::vector<int> displs(size());
Expand All @@ -549,8 +549,8 @@ class Communicator
displs[i] = v[2 * i + 1];
}

CALL_MPI(MPI_Allgatherv, (sendbuf__, count__, type_wrapper<T>::kind(), recvbuf__, counts.data(),
displs.data(), type_wrapper<T>::kind(), this->native()));
CALL_MPI(MPI_Allgatherv, (sendbuf__, count__, type_wrapper<T>(), recvbuf__, counts.data(),
displs.data(), type_wrapper<T>(), this->native()));
}

/// In-place MPI_Allgatherv.
Expand All @@ -563,7 +563,7 @@ class Communicator
v[2 * rank() + 1] = displs__;

CALL_MPI(MPI_Allgather,
(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>::kind(), this->native()));
(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>(), this->native()));

std::vector<int> counts(size());
std::vector<int> displs(size());
Expand All @@ -581,7 +581,7 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Send");
#endif
CALL_MPI(MPI_Send, (buffer__, count__, type_wrapper<T>::kind(), dest__, tag__, this->native()));
CALL_MPI(MPI_Send, (buffer__, count__, type_wrapper<T>(), dest__, tag__, this->native()));
}

template <typename T>
Expand All @@ -591,7 +591,7 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Isend");
#endif
CALL_MPI(MPI_Isend, (buffer__, count__, type_wrapper<T>::kind(), dest__, tag__, this->native(), &req.handler()));
CALL_MPI(MPI_Isend, (buffer__, count__, type_wrapper<T>(), dest__, tag__, this->native(), &req.handler()));
return req;
}

Expand All @@ -602,7 +602,7 @@ class Communicator
PROFILE("MPI_Recv");
#endif
CALL_MPI(MPI_Recv,
(buffer__, count__, type_wrapper<T>::kind(), source__, tag__, this->native(), MPI_STATUS_IGNORE));
(buffer__, count__, type_wrapper<T>(), source__, tag__, this->native(), MPI_STATUS_IGNORE));
}

template <typename T>
Expand All @@ -612,7 +612,7 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Irecv");
#endif
CALL_MPI(MPI_Irecv, (buffer__, count__, type_wrapper<T>::kind(), source__, tag__, this->native(), &req.handler()));
CALL_MPI(MPI_Irecv, (buffer__, count__, type_wrapper<T>(), source__, tag__, this->native(), &req.handler()));
return req;
}

Expand All @@ -624,8 +624,8 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Gatherv");
#endif
CALL_MPI(MPI_Gatherv, (sendbuf__, sendcount, type_wrapper<T>::kind(), recvbuf__, recvcounts__, displs__,
type_wrapper<T>::kind(), root__, this->native()));
CALL_MPI(MPI_Gatherv, (sendbuf__, sendcount, type_wrapper<T>(), recvbuf__, recvcounts__, displs__,
type_wrapper<T>(), root__, this->native()));
}

/// Gather data on a given rank.
Expand All @@ -641,7 +641,7 @@ class Communicator
v[2 * rank() + 1] = offset__;

CALL_MPI(MPI_Allgather,
(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>::kind(), this->native()));
(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>(), this->native()));

std::vector<int> counts(size());
std::vector<int> offsets(size());
Expand All @@ -650,8 +650,8 @@ class Communicator
counts[i] = v[2 * i];
offsets[i] = v[2 * i + 1];
}
CALL_MPI(MPI_Gatherv, (sendbuf__, count__, type_wrapper<T>::kind(), recvbuf__, counts.data(),
offsets.data(), type_wrapper<T>::kind(), root__, this->native()));
CALL_MPI(MPI_Gatherv, (sendbuf__, count__, type_wrapper<T>(), recvbuf__, counts.data(),
offsets.data(), type_wrapper<T>(), root__, this->native()));
}

template <typename T>
Expand All @@ -661,8 +661,8 @@ class Communicator
PROFILE("MPI_Scatterv");
#endif
int recvcount = sendcounts__[rank()];
CALL_MPI(MPI_Scatterv, (sendbuf__, sendcounts__, displs__, type_wrapper<T>::kind(), recvbuf__, recvcount,
type_wrapper<T>::kind(), root__, this->native()));
CALL_MPI(MPI_Scatterv, (sendbuf__, sendcounts__, displs__, type_wrapper<T>(), recvbuf__, recvcount,
type_wrapper<T>(), root__, this->native()));
}

template <typename T>
Expand All @@ -671,8 +671,8 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Alltoall");
#endif
CALL_MPI(MPI_Alltoall, (sendbuf__, sendcounts__, type_wrapper<T>::kind(), recvbuf__, recvcounts__,
type_wrapper<T>::kind(), this->native()));
CALL_MPI(MPI_Alltoall, (sendbuf__, sendcounts__, type_wrapper<T>(), recvbuf__, recvcounts__,
type_wrapper<T>(), this->native()));
}

template <typename T>
Expand All @@ -682,8 +682,8 @@ class Communicator
#if defined(__PROFILE_MPI)
PROFILE("MPI_Alltoallv");
#endif
CALL_MPI(MPI_Alltoallv, (sendbuf__, sendcounts__, sdispls__, type_wrapper<T>::kind(), recvbuf__,
recvcounts__, rdispls__, type_wrapper<T>::kind(), this->native()));
CALL_MPI(MPI_Alltoallv, (sendbuf__, sendcounts__, sdispls__, type_wrapper<T>(), recvbuf__,
recvcounts__, rdispls__, type_wrapper<T>(), this->native()));
}
};

Expand Down