Skip to content

Commit

Permalink
fix getPingLatency in pt2pt mode
Browse files Browse the repository at this point in the history
Summary:
D52715647 incorrectly used P2POp in the blocking `getPingLatency`/`getPingPongLatency`. P2POp only enqueues ops for batch_isend_recv, and needs an additional call to batch_isend_recv to kick off the enqueued ops. Thus, the previous version reported incorrect latency, since the ops were only enqueued but never issued.

This patch fixes it by using the send/recv API instead.

Reviewed By: kingchc

Differential Revision: D52721469

fbshipit-source-id: 9fd44654f1fc97729ff9c333257b59938163f259
  • Loading branch information
minsii authored and facebook-github-bot committed Jan 18, 2024
1 parent 2b441ab commit d4b62fe
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions train/comms/pt/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,17 +630,15 @@ def getPingLatency(self, numIters):
self.collectiveArgs.global_rank
)
self.collectiveArgs.dst_rank = self.collectiveArgs.dst_ranks[idx]
self.collectiveArgs.collective = "send"
self.backendFuncs.P2POp(
self.backendFuncs.send(
collectiveArgs=self.collectiveArgs,
)
elif self.collectiveArgs.global_rank in self.collectiveArgs.dst_ranks:
idx = self.collectiveArgs.dst_ranks.index(
self.collectiveArgs.global_rank
)
self.collectiveArgs.src_rank = self.collectiveArgs.src_ranks[idx]
self.collectiveArgs.collective = "recv"
self.backendFuncs.P2POp(
self.backendFuncs.recv(
collectiveArgs=self.collectiveArgs,
)
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
Expand All @@ -666,27 +664,23 @@ def getPingPongLatency(self, numIters):
self.collectiveArgs.global_rank
)
self.collectiveArgs.dst_rank = self.collectiveArgs.dst_ranks[idx]
self.collectiveArgs.collective = "send"
self.backendFuncs.P2POp(
self.backendFuncs.send(
collectiveArgs=self.collectiveArgs,
)
self.collectiveArgs.src_rank = self.collectiveArgs.dst_ranks[idx]
self.collectiveArgs.collective = "recv"
self.backendFuncs.P2POp(
self.backendFuncs.recv(
collectiveArgs=self.collectiveArgs,
)
elif self.collectiveArgs.global_rank in self.collectiveArgs.dst_ranks:
idx = self.collectiveArgs.dst_ranks.index(
self.collectiveArgs.global_rank
)
self.collectiveArgs.src_rank = self.collectiveArgs.src_ranks[idx]
self.collectiveArgs.collective = "recv"
self.backendFuncs.P2POp(
self.backendFuncs.recv(
collectiveArgs=self.collectiveArgs,
)
self.collectiveArgs.dst_rank = self.collectiveArgs.src_ranks[idx]
self.collectiveArgs.collective = "send"
self.backendFuncs.P2POp(
self.backendFuncs.send(
collectiveArgs=self.collectiveArgs,
)
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
Expand Down

0 comments on commit d4b62fe

Please sign in to comment.