@@ -250,6 +250,8 @@ Zero :: ~Zero()
250250inline void Zero :: tryIncrement(const Op op, const Phase phase,
251251 const TagID tag) noexcept
252252{
253+ if (tag == LPF_MSG_DEFAULT) return ;
254+
253255 switch (phase) {
254256 case Phase::INIT:
255257 // dynamically increase the capacity
@@ -766,13 +768,10 @@ void Zero :: put( SlotID srcSlot, size_t srcOffset,
766768 // we only need a signal from the last message in the queue
767769 sr->send_flags = lastMsg ? IBV_SEND_SIGNALED : 0 ;
768770 sr->opcode = lastMsg? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE;
769- /* use wr_id to later demultiplex srcSlot */
770- sr->wr_id = attr; // srcSlot;
771- /*
772- * In HiCR, we need to know at receiver end which slot
773- * has received the message. But here is a trick:
774- */
775- sr->imm_data = attr; // dstSlot;
771+ // use wr_id to store the comm tag (passed as attr)
772+ sr->wr_id = attr;
773+ // use imm_data to store the comm tag (passed as attr)
774+ sr->imm_data = attr;
776775
777776 sr->sg_list = &sges[i];
778777 sr->num_sge = 1 ;
@@ -785,7 +784,7 @@ void Zero :: put( SlotID srcSlot, size_t srcOffset,
785784 dstOffset += sge->length ;
786785
787786 LOG (4 , " PID " << m_pid << " : Enqueued put message of " << sge->length
788- << " bytes to " << dstPid << " on slot" << dstSlot );
787+ << " bytes to " << dstPid << " on slot" << dstSlot << " and tag " << attr );
789788 }
790789 struct ibv_send_wr *bad_wr = NULL ;
791790 // srs[0] should be sufficient because the rest of srs are on a chain
@@ -910,7 +909,7 @@ void Zero :: doLocalProgress(int& error) {
910909 << wcs[i].vendor_err );
911910 const char * status_descr;
912911 status_descr = ibv_wc_status_str (wcs[i].status );
913- LOG ( 2 , " The work completion status string: " << status_descr);
912+ LOG ( 2 , " Process " << m_pid << " : The work completion status string: " << status_descr);
914913 error = 1 ;
915914 }
916915 else {
@@ -932,7 +931,7 @@ void Zero :: doLocalProgress(int& error) {
932931 // This is a put call completing
933932 if (wcs[i].opcode == IBV_WC_RDMA_WRITE) {
934933 tryIncrement (Op::SEND, Phase::POST, slot);
935- LOG (4 , " Rank " << m_pid << " with SEND, increments getMsgCount to "
934+ LOG (4 , " Rank " << m_pid << " with SEND, increments sentMsgCount to "
936935 << sentMsgCount[slot] << " for LPF slot " << slot);
937936 }
938937
@@ -982,30 +981,48 @@ void Zero :: countingSyncPerSlot(const TagID tag, const size_t expectedSent,
982981 if (expectedSent == 0 ) { sentOK = true ; }
983982 if (expectedRecvd == 0 ) { recvdOK = true ; }
984983 int error;
985- if (tagActive[tag]) {
986- do {
987- doLocalProgress (error);
988- if (error) {
989- LOG (1 , " Error in doLocalProgress" );
990- throw std::runtime_error (" Error in doLocalProgress" );
991- }
992- // this call triggers doRemoteProgress
993- doRemoteProgress ();
994-
995- /*
996- * 1) Are we expecting nothing here (sentOK/recvdOK = true)
997- * 2) do the sent and received messages match our expectations?
998- */
999- sentOK = (sentOK || sentMsgCount[tag] >= expectedSent);
1000- // We can receive messages passively (from remote puts) and actively (from our gets)
1001- recvdOK = (recvdOK || (rcvdMsgCount[tag] + getMsgCount[tag]) >= expectedRecvd);
1002- LOG (4 , " PID: " << m_pid << " rcvdMsgCount[" << tag << " ] = " << rcvdMsgCount[tag]
1003- << " expectedRecvd = " << expectedRecvd
1004- << " rcvdMsgCount[" << tag << " ] = " << rcvdMsgCount[tag]
1005- << " getMsgCount[" << tag << " ] = " << getMsgCount[tag]
1006- << " sentMsgCount[" << tag << " ] = " << sentMsgCount[tag]
1007- << " expectedSent = " << expectedSent);
1008- } while (!(sentOK && recvdOK));
984+
985+ // This is semantically equivalent to a non-blocking test call,
986+ // triggering progress on the network card without expecting anything
987+ // from a particular tag
988+ if (tag == INVALID_TAG && sentOK && recvdOK) {
989+ doLocalProgress (error);
990+ if (error) {
991+ LOG (1 , " Error in doLocalProgress" );
992+ throw std::runtime_error (" Error in doLocalProgress" );
993+ }
994+ // this call triggers doRemoteProgress
995+ doRemoteProgress ();
996+ }
997+
998+ // This is a blocking call on a particular tag with some expected
999+ // sent / received messages
1000+ else {
1001+ if (tagActive[tag]) {
1002+ do {
1003+ doLocalProgress (error);
1004+ if (error) {
1005+ LOG (1 , " Error in doLocalProgress" );
1006+ throw std::runtime_error (" Error in doLocalProgress" );
1007+ }
1008+ // this call triggers doRemoteProgress
1009+ doRemoteProgress ();
1010+
1011+ /*
1012+ * 1) Are we expecting nothing here (sentOK/recvdOK = true)
1013+ * 2) do the sent and received messages match our expectations?
1014+ */
1015+ sentOK = (sentOK || sentMsgCount[tag] >= expectedSent);
1016+ // We can receive messages passively (from remote puts) and actively (from our gets)
1017+ recvdOK = (recvdOK || (rcvdMsgCount[tag] + getMsgCount[tag]) >= expectedRecvd);
1018+ LOG (4 , " PID: " << m_pid << " rcvdMsgCount[" << tag << " ] = " << rcvdMsgCount[tag]
1019+ << " expectedRecvd = " << expectedRecvd
1020+ << " rcvdMsgCount[" << tag << " ] = " << rcvdMsgCount[tag]
1021+ << " getMsgCount[" << tag << " ] = " << getMsgCount[tag]
1022+ << " sentMsgCount[" << tag << " ] = " << sentMsgCount[tag]
1023+ << " expectedSent = " << expectedSent);
1024+ } while (!(sentOK && recvdOK));
1025+ }
10091026 }
10101027}
10111028
@@ -1031,33 +1048,35 @@ void Zero :: syncPerTag(TagID tag) {
10311048
10321049void Zero :: sync(bool resized,const struct SyncAttr * attr)
10331050{
1034- const bool defaultSync = attr == nullptr || (attr->tag == INVALID_TAG &&
1035- attr->expected_sent == 0 && attr->expected_rcvd == 0 );
1051+ const bool defaultSync = (attr == nullptr ) ;
10361052 if (defaultSync)
10371053 {
1054+ LOG (4 , " Process " << m_pid << " going for default sync (uses barrier)" );
10381055 (void ) resized;
10391056
10401057 // flush send queues
10411058 flushSent ();
10421059 // flush receive queues
10431060 flushReceived ();
10441061
1045- LOG (4 , " Process " << m_pid << " will call barrier at end of sync\n " );
10461062 m_comm.barrier ();
10471063
10481064 // done
10491065 return ;
10501066 }
10511067
10521068 ASSERT (attr != NULL );
1069+
10531070 const bool tagSync = attr->expected_sent == 0 && attr->expected_rcvd == 0
10541071 && attr->tag != INVALID_TAG;
10551072 if (tagSync)
10561073 {
1074+ LOG (4 , " Process " << m_pid << " going for syncPerTag (uses barrier)" );
10571075 syncPerTag (attr->tag );
10581076 return ;
10591077 }
10601078
1079+ LOG (4 , " Process " << m_pid << " going for countingSync (no barrier!)" );
10611080 countingSyncPerSlot (attr->tag ,attr->expected_sent ,attr->expected_rcvd );
10621081}
10631082
0 commit comments