
Commit adc9a1e

rohan-varma authored and pytorchmergebot committed

Enforce explicit ProcessGroup passed into DefaultState (pytorch#84105)

We would prefer to enforce that users pass an explicit ProcessGroup into these state objects when using comm hooks with FSDP, so that it is clear, and easily debuggable, over which processes communication takes place.

Pull Request resolved: pytorch#84105
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao

1 parent 092fe71 commit adc9a1e

File tree

2 files changed: +7 −5 lines changed

test/distributed/fsdp/test_fsdp_comm_hooks.py

Lines changed: 3 additions & 2 deletions

@@ -7,6 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import distributed as dist
+from torch.distributed.distributed_c10d import _get_default_group
 from torch.distributed.algorithms._comm_hooks import default_hooks
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision
@@ -423,7 +424,7 @@ def test_fp16_hook(
         sharding_strategy: Optional[ShardingStrategy]
     ):

-        state = default_hooks.LowPrecisionState(process_group=None)
+        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
         hook = default_hooks.fp16_compress_hook

         self._check_low_precision_hook(state, hook, sharding_strategy, torch.float16, has_wrapping)
@@ -450,7 +451,7 @@ def test_bf16_hook(
         sharding_strategy: Optional[ShardingStrategy]
     ):

-        state = default_hooks.LowPrecisionState(process_group=None)
+        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
         hook = default_hooks.bf16_compress_hook

         self._check_low_precision_hook(state, hook, sharding_strategy, torch.bfloat16, has_wrapping)

torch/distributed/algorithms/_comm_hooks/default_hooks.py

Lines changed: 4 additions & 3 deletions

@@ -1,7 +1,6 @@
 import functools
 import torch
 import torch.distributed as dist
-from torch.distributed import distributed_c10d


 class DefaultState(object):
@@ -22,9 +21,11 @@ class DefaultState(object):

     def __init__(
         self,
-        process_group
+        process_group: dist.ProcessGroup
     ):
-        self.process_group = process_group if process_group is not None else distributed_c10d._get_default_group()
+        if process_group is None:
+            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
+        self.process_group = process_group
         self.world_size = dist.get_world_size(process_group)
         # Setting two factors `self.gradient_predivide_factor`
         # and `self.gradient_postdivide_factor` to avoid underflow and overflow
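
For context, a minimal sketch of the contract after this change, using only names that appear in the diff above. It assumes the process group has already been initialized (e.g. via dist.init_process_group) before _get_default_group() is called:

    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _get_default_group
    from torch.distributed.algorithms._comm_hooks import default_hooks

    # Passing None no longer falls back silently to the default group;
    # it now raises before any communication state is set up.
    try:
        default_hooks.DefaultState(process_group=None)
    except ValueError as e:
        print(e)  # Expected to pass in an explicit ProcessGroup to ...

    # Callers must be explicit, even when they want the default (world) group.
    state = default_hooks.LowPrecisionState(process_group=_get_default_group())
    hook = default_hooks.fp16_compress_hook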
