
Commit 2ec4f69

Yi Wang authored and facebook-github-bot committed
[DDP Comm Hook] Do not expose hook_then_optimizer as a public method (pytorch#62532)
Summary:
Pull Request resolved: pytorch#62532

This method is not stable at this time, so avoid releasing it when the DDP communication hook feature is released as a stable feature.

ghstack-source-id: 134787831

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_ddp_hook_with_optimizer_parity
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_hook_then_optimizer_nccl

Reviewed By: rohan-varma

Differential Revision: D30031222

fbshipit-source-id: e03a8e13fee5116a5ddd724eb76316ee98f2a676
1 parent b161ac5 commit 2ec4f69

File tree

3 files changed: 15 additions, 8 deletions

test/distributed/test_c10d_nccl.py

Lines changed: 2 additions & 2 deletions
@@ -1607,12 +1607,12 @@ def _test_hook_then_optimizer(self, gradient_as_bucket_view=False):
         sgd_lr = 1e-2
         sgd_momentum = 0.9
         sgd_weight_decay = 0.01
-        opt_hook_state = default.OptimizerHookState(
+        opt_hook_state = default._OptimizerHookState(
             _FunctionalSGD, sgd_lr, momentum=sgd_momentum, weight_decay=sgd_weight_decay
         )
         gpu_model = self._gpu_model_with_ddp_comm_hook(
             process_group,
-            default.hook_then_optimizer(hook, opt_hook_state),
+            default._hook_then_optimizer(hook, opt_hook_state),
             gradient_as_bucket_view,
             hook_state,
         )

torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py

Lines changed: 11 additions & 4 deletions
@@ -70,7 +70,7 @@ def decompress(fut):
     return fut.then(decompress)
 
 
-class OptimizerHookState(object):
+class _OptimizerHookState(object):
     """
     Holds state for running optimizer in-line after DDP communication hook.
     Currently contains only optimizer class which must have a method `step_param`.
@@ -93,11 +93,18 @@ def __init__(
         )
 
 
-def hook_then_optimizer(
+# TODO: Add an example to use such a wrapper.
+def _hook_then_optimizer(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
-    optimizer_state: OptimizerHookState,
+    optimizer_state: _OptimizerHookState,
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
-    """Runs optimizer in a functional fashion after DDP communication hook."""
+    r"""
+    Runs optimizer in a functional fashion after DDP communication hook.
+
+    .. warning ::
+        This API is experimental and subject to change.
+    """
+
 
     def hook_then_optimizer_wrapper(
         hook_state, bucket: dist.GradBucket
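
The committed code itself leaves a "TODO: Add an example to use such a wrapper." Below is a minimal sketch, not part of this commit, of how the renamed private helpers fit together, modeled on the test changes in this diff. The `_FunctionalSGD` import path and the one-GPU-per-rank device mapping are assumptions, not something this commit specifies.

# Sketch only: wires the (now private) optimizer-in-hook helpers together the
# way the updated tests do. Assumes torch.distributed is already initialized
# with a NCCL process group and one GPU per rank.
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
# Import path for the functional SGD optimizer is an assumption; it may differ
# across PyTorch versions.
from torch.distributed.optim.functional_sgd import _FunctionalSGD
from torch.nn.parallel import DistributedDataParallel as DDP

rank = dist.get_rank()
model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])

# Holds the functional optimizer (which must expose `step_param`) that the
# wrapped hook steps after each bucket's communication completes.
opt_hook_state = default._OptimizerHookState(
    _FunctionalSGD, 1e-2, momentum=0.9, weight_decay=0.01
)

# state=None lets allreduce_hook fall back to the default process group.
model.register_comm_hook(
    None, default._hook_then_optimizer(default.allreduce_hook, opt_hook_state)
)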

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 2 additions & 2 deletions
@@ -3865,15 +3865,15 @@ def _test_ddp_hook_with_optimizer_parity(
 
         # Register hook that runs allreduce + functional SGD step.
         allreduce_hook = default.allreduce_hook
-        opt_hook_state = default.OptimizerHookState(
+        opt_hook_state = default._OptimizerHookState(
             _FunctionalSGD,
             sgd_lr,
             momentum=sgd_momentum,
             weight_decay=sgd_weight_decay,
         )
         ddp_model_with_optimizer_hook.register_comm_hook(
             None,
-            default.hook_then_optimizer(allreduce_hook, opt_hook_state),
+            default._hook_then_optimizer(allreduce_hook, opt_hook_state),
         )
         # Create DDP model with no hook that does optimizer after
         # backward.
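
For context, the test touched here (`_test_ddp_hook_with_optimizer_parity`) compares the hook-fused optimizer against a plain DDP model that steps its optimizer after backward. The following is a rough, hypothetical sketch of that parity check, not the actual test code; the model, optimizer, and input arguments are placeholders.

# Hypothetical, simplified parity check: the model whose comm hook fuses
# allreduce + functional SGD should track a reference model that runs an
# ordinary optimizer.step() after backward.
import torch

def check_parity(hooked_model, reference_model, reference_optimizer, inp, num_iters=10):
    for _ in range(num_iters):
        # Hooked model: the SGD step happens inside the communication hook.
        hooked_model.zero_grad()
        hooked_model(inp).sum().backward()

        # Reference model: plain backward, then an explicit optimizer step.
        reference_optimizer.zero_grad()
        reference_model(inp).sum().backward()
        reference_optimizer.step()

        # Both models should end up with (nearly) identical parameters.
        for p_hook, p_ref in zip(hooked_model.parameters(), reference_model.parameters()):
            assert torch.allclose(p_hook, p_ref)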
