
Commit 9afaf02

sryap authored and facebook-github-bot committed
Fix the sync point caused by iter_cpu.item() (pytorch#489)
Summary:
X-link: pytorch#3401
Pull Request resolved: facebookresearch/FBGEMM#489

Although `self.iter_cpu` is created on CPU, it might be transferred to GPU by the user, so we need to transfer it back to CPU explicitly. This should be done only once.

Reviewed By: csmiler

Differential Revision: D66311970

fbshipit-source-id: 4ced9b28f2e6c69fc12ccdc7e88404caf8276627
1 parent 67c870e commit 9afaf02

File tree

1 file changed: +6 additions, 0 deletions


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 6 additions & 0 deletions
@@ -1824,6 +1824,12 @@ def forward(  # noqa: C901
             # `Union[Module, Tensor]`.
             placements=self.momentum2_placements,
         )
+
+        # Although self.iter_cpu is created on CPU. It might be transferred to
+        # GPU by the user. So, we need to transfer it to CPU explicitly. This
+        # should be done only once.
+        self.iter_cpu = self.iter_cpu.cpu()
+
         # Sync with loaded state
         if (
             not is_torchdynamo_compiling()
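The pattern the commit applies can be sketched in isolation. Calling `.item()` on a GPU tensor forces a host-device synchronization point, while `.item()` on a CPU tensor does not; moving the counter back to CPU once makes every subsequent `.item()` call sync-free. This is a minimal illustration with hypothetical names (`Counter`, `step`), not the actual FBGEMM class:

```python
import torch


class Counter:
    """Hypothetical stand-in for a module holding an iteration counter."""

    def __init__(self) -> None:
        # Created on CPU, but a user may later move the whole module to GPU,
        # dragging this tensor along with it.
        self.iter_cpu = torch.zeros(1, dtype=torch.int64)

    def step(self) -> int:
        # Transfer back to CPU explicitly. Tensor.cpu() returns the original
        # tensor unchanged when it already lives on CPU, so after the first
        # call this line is effectively a no-op.
        self.iter_cpu = self.iter_cpu.cpu()
        self.iter_cpu += 1
        # .item() here cannot trigger a GPU sync: the tensor is guaranteed
        # to be on CPU at this point.
        return int(self.iter_cpu.item())


counter = Counter()
print(counter.step())  # 1
print(counter.step())  # 2
```

In the real commit the `.cpu()` call is placed in `forward`, so it also repairs state that was moved to GPU between calls (e.g. by `module.cuda()` or when loading a checkpoint onto a device).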
