From d4b62fec3ab1fb865926d40f42fd3972efb95706 Mon Sep 17 00:00:00 2001 From: Min Si Date: Wed, 17 Jan 2024 19:23:15 -0800 Subject: [PATCH] fix getPingLatency in pt2pt mode Summary: D52715647 incorrectly used P2POp in blocking `getPingLatency`/`getPingPongLatency`. P2POp enqueues ops for batch_isend_recv, and need additional call to batch_isend_recv to kick off enqueued ops. Thus, the previous version caused incorrect latency, since we just enqueued but didn't issue. This patch fixes it by using send/recv API. Reviewed By: kingchc Differential Revision: D52721469 fbshipit-source-id: 9fd44654f1fc97729ff9c333257b59938163f259 --- train/comms/pt/comms.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/train/comms/pt/comms.py b/train/comms/pt/comms.py index cd3ad13b..049d4121 100755 --- a/train/comms/pt/comms.py +++ b/train/comms/pt/comms.py @@ -630,8 +630,7 @@ def getPingLatency(self, numIters): self.collectiveArgs.global_rank ) self.collectiveArgs.dst_rank = self.collectiveArgs.dst_ranks[idx] - self.collectiveArgs.collective = "send" - self.backendFuncs.P2POp( + self.backendFuncs.send( collectiveArgs=self.collectiveArgs, ) elif self.collectiveArgs.global_rank in self.collectiveArgs.dst_ranks: @@ -639,8 +638,7 @@ def getPingLatency(self, numIters): self.collectiveArgs.global_rank ) self.collectiveArgs.src_rank = self.collectiveArgs.src_ranks[idx] - self.collectiveArgs.collective = "recv" - self.backendFuncs.P2POp( + self.backendFuncs.recv( collectiveArgs=self.collectiveArgs, ) self.backendFuncs.complete_accel_ops(self.collectiveArgs) @@ -666,13 +664,11 @@ def getPingPongLatency(self, numIters): self.collectiveArgs.global_rank ) self.collectiveArgs.dst_rank = self.collectiveArgs.dst_ranks[idx] - self.collectiveArgs.collective = "send" - self.backendFuncs.P2POp( + self.backendFuncs.send( collectiveArgs=self.collectiveArgs, ) self.collectiveArgs.src_rank = self.collectiveArgs.dst_ranks[idx] - self.collectiveArgs.collective = "recv" - self.backendFuncs.P2POp( + self.backendFuncs.recv( collectiveArgs=self.collectiveArgs, ) elif self.collectiveArgs.global_rank in self.collectiveArgs.dst_ranks: @@ -680,13 +676,11 @@ def getPingPongLatency(self, numIters): self.collectiveArgs.global_rank ) self.collectiveArgs.src_rank = self.collectiveArgs.src_ranks[idx] - self.collectiveArgs.collective = "recv" - self.backendFuncs.P2POp( + self.backendFuncs.recv( collectiveArgs=self.collectiveArgs, ) self.collectiveArgs.dst_rank = self.collectiveArgs.src_ranks[idx] - self.collectiveArgs.collective = "send" - self.backendFuncs.P2POp( + self.backendFuncs.send( collectiveArgs=self.collectiveArgs, ) self.backendFuncs.complete_accel_ops(self.collectiveArgs)