diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 464c56a548a657..099c2b164daa61 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -56,7 +56,7 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, combined_size * world_size, dtype=combined.dtype, device=combined.device) - dist._all_gather_base(combined_flat, combined, process_group, async_op=False) + dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False) combined = torch.reshape(combined_flat, (world_size, combined_size)) # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)