Locality Aware Broadcast #185

Merged: 7 commits, Dec 6, 2021
24 changes: 22 additions & 2 deletions include/faabric/scheduler/MpiWorld.h
@@ -20,6 +20,17 @@
#define MPI_MSGTYPE_COUNT_PREFIX "mpi-msgtype-torank"

namespace faabric::scheduler {

// -----------------------------------
// Mocking
// -----------------------------------
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
// as the broker already has mocking capabilities
std::vector<faabric::MpiHostsToRanksMessage> getMpiHostsToRanksMessages();

std::vector<std::shared_ptr<faabric::MPIMessage>> getMpiMockedMessages(
int sendRank);

typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

@@ -76,8 +87,9 @@ class MpiWorld
faabric::MPIMessage::MPIMessageType messageType =
faabric::MPIMessage::NORMAL);

- void broadcast(int sendRank,
- const uint8_t* buffer,
void broadcast(int rootRank,
int thisRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
faabric::MPIMessage::MPIMessageType messageType =
@@ -214,6 +226,14 @@ class MpiWorld
// Track at which host each rank lives
std::vector<std::string> rankHosts;
int getIndexForRanks(int sendRank, int recvRank);
std::vector<int> getRanksForHost(const std::string& host);

// Track ranks that are local to this world, and local/remote leaders
// MPITOPTP - this information exists in the broker
int localLeader;
std::vector<int> localRanks;
std::vector<int> remoteLeaders;
void initLocalRemoteLeaders();

// In-memory queues for local messaging
std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
22 changes: 8 additions & 14 deletions src/mpi_native/mpi_native.cpp
@@ -212,20 +212,14 @@ int MPI_Bcast(void* buffer,
faabric::scheduler::MpiWorld& world = getExecutingWorld();

int rank = executingContext.getRank();
- if (rank == root) {
- SPDLOG_DEBUG(fmt::format("MPI_Bcast {} -> all", rank));
- world.broadcast(
- rank, (uint8_t*)buffer, datatype, count, faabric::MPIMessage::NORMAL);
- } else {
- SPDLOG_DEBUG(fmt::format("MPI_Bcast {} <- {}", rank, root));
- world.recv(root,
- rank,
- (uint8_t*)buffer,
- datatype,
- count,
- nullptr,
- faabric::MPIMessage::NORMAL);
- }
SPDLOG_DEBUG("MPI_Bcast {} -> all", rank);
world.broadcast(root,
rank,
(uint8_t*)buffer,
datatype,
count,
faabric::MPIMessage::BROADCAST);

return MPI_SUCCESS;
}
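For context, the change above replaces the split root/non-root logic with a single code path: every rank now calls world.broadcast() with the root and its own rank. At the application level the call keeps the standard MPI signature; a minimal sketch (variable names are illustrative, not from this PR):

// Every rank executes the same line; the root's buffer is sent and the
// other ranks' buffers are filled in
int value = (rank == 0) ? 42 : 0;
MPI_Bcast(&value, 1, MPI_INT, 0, MPI_COMM_WORLD);
// After the call, value == 42 on every rank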

1 change: 1 addition & 0 deletions src/proto/faabric.proto
@@ -74,6 +74,7 @@ message MPIMessage {
ALLREDUCE = 8;
ALLTOALL = 9;
SENDRECV = 10;
BROADCAST = 11;
};

MPIMessageType messageType = 1;
206 changes: 147 additions & 59 deletions src/scheduler/MpiWorld.cpp
@@ -37,10 +37,32 @@ static thread_local std::unordered_map<
// Id of the message that created this thread-local instance
static thread_local faabric::Message* thisRankMsg = nullptr;

- // This is used for mocking in tests
namespace faabric::scheduler {

// -----------------------------------
// Mocking
// -----------------------------------
static std::mutex mockMutex;

static std::vector<faabric::MpiHostsToRanksMessage> rankMessages;

- namespace faabric::scheduler {
// The identifier in this map is the sending rank. For the receiver's rank
// we can inspect the MPIMessage object
static std::map<int, std::vector<std::shared_ptr<faabric::MPIMessage>>>
mpiMockedMessages;

std::vector<faabric::MpiHostsToRanksMessage> getMpiHostsToRanksMessages()
{
faabric::util::UniqueLock lock(mockMutex);
return rankMessages;
}

std::vector<std::shared_ptr<faabric::MPIMessage>> getMpiMockedMessages(
int sendRank)
{
faabric::util::UniqueLock lock(mockMutex);
return mpiMockedMessages[sendRank];
}
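A rough sketch of how a test can drive these hooks, assuming the existing faabric::util::setMockMode helper and an already-created world plus rank variables (all names here are illustrative, not part of this PR):

faabric::util::setMockMode(true);

// With mock mode on, send() records the message instead of dispatching it
int data = 123;
world.send(
  0, 1, reinterpret_cast<uint8_t*>(&data), MPI_INT, 1, faabric::MPIMessage::NORMAL);

// The test can then inspect what rank 0 sent
auto sent = faabric::scheduler::getMpiMockedMessages(0);
// sent.size() == 1, and sent.at(0) holds the message buffered for rank 1

faabric::util::setMockMode(false);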

MpiWorld::MpiWorld()
: thisHost(faabric::util::getSystemConfig().endpointHost)
@@ -223,6 +245,12 @@ void MpiWorld::create(faabric::Message& call, int newId, int newSize)
rankHosts = executedAt;
basePorts = initLocalBasePorts(executedAt);

// Record which ranks are local to this world, and query for all leaders
initLocalRemoteLeaders();
// Given that we are initialising the whole MpiWorld here, the local leader
// should also be rank 0
assert(localLeader == 0);

// Initialise the memory queues for message reception
initLocalQueues();
}
@@ -295,6 +323,10 @@ void MpiWorld::destroy()
iSendRequests.size());
throw std::runtime_error("Destroying world with outstanding requests");
}

// Clear structures used for mocking
rankMessages.clear();
mpiMockedMessages.clear();
}

void MpiWorld::initialiseFromMsg(faabric::Message& msg)
@@ -322,6 +354,9 @@ void MpiWorld::initialiseFromMsg(faabric::Message& msg)
basePorts = { hostRankMsg.baseports().begin(),
hostRankMsg.baseports().end() };

// Record which ranks are local to this world, and query for all leaders
initLocalRemoteLeaders();

// Initialise the memory queues for message reception
initLocalQueues();
}
@@ -344,6 +379,40 @@ std::string MpiWorld::getHostForRank(int rank)
return host;
}

std::vector<int> MpiWorld::getRanksForHost(const std::string& host)
{
assert(rankHosts.size() == size);

std::vector<int> ranksForHost;
for (int i = 0; i < rankHosts.size(); i++) {
if (rankHosts.at(i) == host) {
ranksForHost.push_back(i);
}
}

return ranksForHost;
}

// The local leader for an MPI world is defined as the lowest rank assigned to
// this host
void MpiWorld::initLocalRemoteLeaders()
{
std::set<std::string> uniqueHosts(rankHosts.begin(), rankHosts.end());

for (const std::string& host : uniqueHosts) {
auto ranksInHost = getRanksForHost(host);
// Persist the ranks co-located on this host for later use
if (host == thisHost) {
localRanks = ranksInHost;
localLeader =
*std::min_element(ranksInHost.begin(), ranksInHost.end());
} else {
remoteLeaders.push_back(
*std::min_element(ranksInHost.begin(), ranksInHost.end()));
}
}
}
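As a worked example (hypothetical two-host layout, not taken from this PR), suppose the world has four ranks with rankHosts = {"hostA", "hostA", "hostB", "hostB"}:

// Result of initLocalRemoteLeaders() when thisHost == "hostA":
//   localRanks    == {0, 1}
//   localLeader   == 0      (lowest rank on this host)
//   remoteLeaders == {2}    (lowest rank on each remote host)
// On "hostB" the same call yields localRanks == {2, 3}, localLeader == 2,
// and remoteLeaders == {0}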

// Returns a pair (sendPort, recvPort)
// To assign the send and recv ports, we follow a protocol establishing:
// 1) Port range (offset) corresponding to the world that receives
@@ -580,6 +649,12 @@ void MpiWorld::send(int sendRank,
m->set_buffer(buffer, dataType->size * count);
}

// Mock the message sending in tests
if (faabric::util::isMockMode()) {
mpiMockedMessages[sendRank].push_back(m);
return;
}

// Dispatch the message locally or globally
if (isLocal) {
SPDLOG_TRACE("MPI - send {} -> {}", sendRank, recvRank);
@@ -616,6 +691,11 @@ void MpiWorld::recv(int sendRank,
// Sanity-check input parameters
checkRanksRange(sendRank, recvRank);

// If mocking the messages, ignore calls to receive that may block
if (faabric::util::isMockMode()) {
return;
}

// Recv message from underlying transport
std::shared_ptr<faabric::MPIMessage> m =
recvBatchReturnLast(sendRank, recvRank);
@@ -699,21 +779,59 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
}

void MpiWorld::broadcast(int sendRank,
- const uint8_t* buffer,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
SPDLOG_TRACE("MPI - bcast {} -> all", sendRank);

- for (int r = 0; r < size; r++) {
- // Skip this rank (it's doing the broadcasting)
- if (r == sendRank) {
- continue;
if (recvRank == sendRank) {
// The sending rank sends a message to all other local ranks in the
// broadcast, and to all remote leaders
for (const int localRecvRank : localRanks) {
if (localRecvRank == recvRank) {
continue;
}

send(recvRank, localRecvRank, buffer, dataType, count, messageType);
}

for (const int remoteRecvRank : remoteLeaders) {
send(
recvRank, remoteRecvRank, buffer, dataType, count, messageType);
}
} else if (recvRank == localLeader) {
// If we are the local leader, first we receive the message sent by
// the sending rank
recv(sendRank, recvRank, buffer, dataType, count, nullptr, messageType);

// If the broadcast originated locally, we are done. If not, we now
// distribute to all our local ranks
if (getHostForRank(sendRank) != thisHost) {
for (const int localRecvRank : localRanks) {
if (localRecvRank == recvRank) {
continue;
}

- // Send to the other ranks
- send(sendRank, r, buffer, dataType, count, messageType);
send(recvRank,
localRecvRank,
buffer,
dataType,
count,
messageType);
}
}
} else {
// If we are neither the sending rank nor the local leader, we receive
// either from our local leader, if the broadcast originated on a
// different host, or from the sending rank itself if we are on the same host
int sendingRank =
getHostForRank(sendRank) == thisHost ? sendRank : localLeader;

recv(
sendingRank, recvRank, buffer, dataType, count, nullptr, messageType);
}
}
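To make the message pattern concrete, here is the flow for the same hypothetical layout (ranks 0 and 1 on hostA, ranks 2 and 3 on hostB) when rank 0 is the root and every rank calls broadcast(0, thisRank, ...):

// rank 0 (root, local leader on hostA): send 0 -> 1 (local), send 0 -> 2 (remote leader)
// rank 2 (local leader on hostB):       recv 0 -> 2, then send 2 -> 3 (local)
// rank 1:                               recv 0 -> 1 (root is on the same host)
// rank 3:                               recv 2 -> 3 (from its local leader)
//
// Only the 0 -> 2 message crosses hosts; the old rank-by-rank loop sent two
// cross-host messages (0 -> 2 and 0 -> 3) from the root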

@@ -874,23 +992,14 @@ void MpiWorld::allGather(int rank,
// Note that sendCount and recvCount here are per-rank, so we need to work
// out the full buffer size
int fullCount = recvCount * size;
- if (rank == root) {
- // Broadcast the result
- broadcast(root,
- recvBuffer,
- recvType,
- fullCount,
- faabric::MPIMessage::ALLGATHER);
- } else {
- // Await the broadcast from the master
- recv(root,
- rank,
- recvBuffer,
- recvType,
- fullCount,
- nullptr,
- faabric::MPIMessage::ALLGATHER);
- }

// Do a broadcast with a hard-coded root
broadcast(root,
rank,
recvBuffer,
recvType,
fullCount,
faabric::MPIMessage::ALLGATHER);
}

void MpiWorld::awaitAsyncRequest(int requestId)
Expand Down Expand Up @@ -1006,26 +1115,12 @@ void MpiWorld::allReduce(int rank,
faabric_op_t* operation)
{
// Rank 0 coordinates the allreduce operation
- if (rank == 0) {
- // Run the standard reduce
- reduce(0, 0, sendBuffer, recvBuffer, datatype, count, operation);

- // Broadcast the result
- broadcast(
- 0, recvBuffer, datatype, count, faabric::MPIMessage::ALLREDUCE);
- } else {
- // Run the standard reduce
- reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation);
// First, all ranks reduce to rank 0
reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation);

- // Await the broadcast from the master
- recv(0,
- rank,
- recvBuffer,
- datatype,
- count,
- nullptr,
- faabric::MPIMessage::ALLREDUCE);
- }
// Second, rank 0 broadcasts the result to all ranks
broadcast(
0, rank, recvBuffer, datatype, count, faabric::MPIMessage::ALLREDUCE);
}
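With this structure every rank executes the same pair of collectives, so the caller-side pattern is symmetric. A small sketch with hypothetical values, assuming MPI_SUM is the usual faabric_op_t* from faabric's mpi.h and following the parameter order shown above:

int localVal = thisRank;
int result = 0;
world.allReduce(thisRank,
                reinterpret_cast<uint8_t*>(&localVal),
                reinterpret_cast<uint8_t*>(&result),
                MPI_INT,
                1,
                MPI_SUM);
// Every rank first reduces into rank 0 and then receives the broadcast
// result, so result == 0 + 1 + ... + (size - 1) on all ranks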

void MpiWorld::op_reduce(faabric_op_t* operation,
@@ -1244,8 +1339,9 @@ void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)

void MpiWorld::barrier(int thisRank)
{
// Rank 0 coordinates the barrier operation
if (thisRank == 0) {
- // This is the root, hence just does the waiting
// This is the root, hence waits for all ranks to get to the barrier
SPDLOG_TRACE("MPI - barrier init {}", thisRank);

// Await messages from all others
@@ -1255,25 +1351,17 @@
r, 0, nullptr, MPI_INT, 0, &s, faabric::MPIMessage::BARRIER_JOIN);
SPDLOG_TRACE("MPI - recv barrier join {}", s.MPI_SOURCE);
}

- // Broadcast that the barrier is done
- broadcast(0, nullptr, MPI_INT, 0, faabric::MPIMessage::BARRIER_DONE);
} else {
// Tell the root that we're waiting
SPDLOG_TRACE("MPI - barrier join {}", thisRank);
send(
thisRank, 0, nullptr, MPI_INT, 0, faabric::MPIMessage::BARRIER_JOIN);

- // Receive a message saying the barrier is done
- recv(0,
- thisRank,
- nullptr,
- MPI_INT,
- 0,
- nullptr,
- faabric::MPIMessage::BARRIER_DONE);
- SPDLOG_TRACE("MPI - barrier done {}", thisRank);
}

// Rank 0 broadcasts that the barrier is done (the others block here)
broadcast(
0, thisRank, nullptr, MPI_INT, 0, faabric::MPIMessage::BARRIER_DONE);
SPDLOG_TRACE("MPI - barrier done {}", thisRank);
}
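The barrier now reuses the locality-aware broadcast for its release phase. A hypothetical trace for a three-rank world, where every rank calls world.barrier(thisRank):

// ranks 1 and 2: send BARRIER_JOIN to rank 0, then block inside broadcast()
//                waiting for BARRIER_DONE
// rank 0:        recv BARRIER_JOIN from ranks 1 and 2, then broadcast
//                BARRIER_DONE (routed via local ranks and remote leaders)
// Every rank returns from barrier() once BARRIER_DONE has reached it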

std::shared_ptr<InMemoryMpiQueue> MpiWorld::getLocalQueue(int sendRank,