Skip to content

Commit

Permalink
Update SyncBatchNorm _all_gather_base to all_gather_into_tensor (pyto…
Browse files Browse the repository at this point in the history
…rch#89521)

Summary: Fixes pytorch#88568

`_all_gather_base` is deprecated. So replacing its usage with `all_gather_into_tensor`

Test Plan: CI

Differential Revision: D41479983

Pull Request resolved: pytorch#89521
Approved by: https://github.com/wz337
  • Loading branch information
H-Huang authored and pytorchmergebot committed Nov 24, 2022
1 parent 94a88b5 commit 9497552
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch/nn/modules/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum,
combined_size * world_size,
dtype=combined.dtype,
device=combined.device)
dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
combined = torch.reshape(combined_flat, (world_size, combined_size))
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
Expand Down

0 comments on commit 9497552

Please sign in to comment.