Skip to content

Commit 5cf4542

Browse files
Revert "Enforce explicit ProcessGroup passed into DefaultState (pytorch#84105)"
This reverts commit adc9a1e. Reverted pytorch#84105 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
1 parent ff23f3a commit 5cf4542

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

test/distributed/fsdp/test_fsdp_comm_hooks.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch.nn as nn
88
import torch.nn.functional as F
99
from torch import distributed as dist
10-
from torch.distributed.distributed_c10d import _get_default_group
1110
from torch.distributed.algorithms._comm_hooks import default_hooks
1211
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1312
from torch.distributed.fsdp import MixedPrecision
@@ -424,7 +423,7 @@ def test_fp16_hook(
424423
sharding_strategy: Optional[ShardingStrategy]
425424
):
426425

427-
state = default_hooks.LowPrecisionState(process_group=_get_default_group())
426+
state = default_hooks.LowPrecisionState(process_group=None)
428427
hook = default_hooks.fp16_compress_hook
429428

430429
self._check_low_precision_hook(state, hook, sharding_strategy, torch.float16, has_wrapping)
@@ -451,7 +450,7 @@ def test_bf16_hook(
451450
sharding_strategy: Optional[ShardingStrategy]
452451
):
453452

454-
state = default_hooks.LowPrecisionState(process_group=_get_default_group())
453+
state = default_hooks.LowPrecisionState(process_group=None)
455454
hook = default_hooks.bf16_compress_hook
456455

457456
self._check_low_precision_hook(state, hook, sharding_strategy, torch.bfloat16, has_wrapping)

torch/distributed/algorithms/_comm_hooks/default_hooks.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import torch
33
import torch.distributed as dist
4+
from torch.distributed import distributed_c10d
45

56

67
class DefaultState(object):
@@ -21,11 +22,9 @@ class DefaultState(object):
2122

2223
def __init__(
2324
self,
24-
process_group: dist.ProcessGroup
25+
process_group
2526
):
26-
if process_group is None:
27-
raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
28-
self.process_group = process_group
27+
self.process_group = process_group if process_group is not None else distributed_c10d._get_default_group()
2928
self.world_size = dist.get_world_size(process_group)
3029
# Setting two factors `self.gradient_predivide_factor`
3130
# and `self.gradient_postdivide_factor` to avoid underflow and overflow

0 commit comments

Comments
 (0)