|
4 | 4 |
|
5 | 5 | #include <faabric/scheduler/MpiWorld.h>
|
6 | 6 | #include <faabric/util/exec_graph.h>
|
| 7 | +#include <faabric/util/json.h> |
7 | 8 | #include <faabric/util/macros.h>
|
8 | 9 |
|
| 10 | +// DELEETEEE ME |
| 11 | +#include <faabric/util/logging.h> |
| 12 | + |
9 | 13 | namespace tests {
|
10 | 14 | TEST_CASE_METHOD(MpiTestFixture,
|
11 | 15 | "Test tracing the number of MPI messages",
|
@@ -74,4 +78,59 @@ TEST_CASE_METHOD(MpiTestFixture,
|
74 | 78 | REQUIRE(msg.intexecgraphdetails_size() == 0);
|
75 | 79 | REQUIRE(msg.execgraphdetails_size() == 0);
|
76 | 80 | }
|
| 81 | + |
| 82 | +TEST_CASE_METHOD(MpiBaseTestFixture, |
| 83 | + "Test different threads populate the graph", |
| 84 | + "[util][exec-graph]") |
| 85 | +{ |
| 86 | + int rank = 0; |
| 87 | + int otherRank = 1; |
| 88 | + int worldSize = 2; |
| 89 | + int worldId = 123; |
| 90 | + |
| 91 | + faabric::Message msg = faabric::util::messageFactory("mpi", "hellompi"); |
| 92 | + msg.set_ismpi(true); |
| 93 | + msg.set_recordexecgraph(true); |
| 94 | + msg.set_mpiworldsize(worldSize); |
| 95 | + msg.set_mpiworldid(worldId); |
| 96 | + |
| 97 | + faabric::Message otherMsg = msg; |
| 98 | + otherMsg.set_mpirank(otherRank); |
| 99 | + msg.set_mpirank(rank); |
| 100 | + |
| 101 | + faabric::scheduler::MpiWorld& thisWorld = |
| 102 | + faabric::scheduler::getMpiWorldRegistry().createWorld(msg, worldId); |
| 103 | + |
| 104 | + std::vector<int> messageData = { 0, 1, 2 }; |
| 105 | + auto buffer = new int[messageData.size()]; |
| 106 | + std::thread otherWorldThread([&messageData, &otherMsg, rank, otherRank] { |
| 107 | + faabric::scheduler::MpiWorld& otherWorld = |
| 108 | + faabric::scheduler::getMpiWorldRegistry().getOrInitialiseWorld( |
| 109 | + otherMsg); |
| 110 | + |
| 111 | + otherWorld.send(otherRank, |
| 112 | + rank, |
| 113 | + BYTES(messageData.data()), |
| 114 | + MPI_INT, |
| 115 | + messageData.size()); |
| 116 | + |
| 117 | + otherWorld.destroy(); |
| 118 | + }); |
| 119 | + |
| 120 | + thisWorld.recv( |
| 121 | + otherRank, rank, BYTES(buffer), MPI_INT, messageData.size(), nullptr); |
| 122 | + |
| 123 | + thisWorld.destroy(); |
| 124 | + |
| 125 | + if (otherWorldThread.joinable()) { |
| 126 | + otherWorldThread.join(); |
| 127 | + } |
| 128 | + |
| 129 | + std::string expectedKey = |
| 130 | + faabric::util::exec_graph::mpiMsgCountPrefix + std::to_string(rank); |
| 131 | + REQUIRE(otherMsg.mpirank() == otherRank); |
| 132 | + REQUIRE(otherMsg.intexecgraphdetails_size() == 1); |
| 133 | + REQUIRE(otherMsg.intexecgraphdetails().count(expectedKey) == 1); |
| 134 | + REQUIRE(otherMsg.intexecgraphdetails().at(expectedKey) == 1); |
| 135 | +} |
77 | 136 | }
|
0 commit comments