[BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (pytorch#128869)

Part of pytorch#123062

- pytorch#123062

Pull Request resolved: pytorch#128869
Approved by: https://github.com/fegin
ghstack dependencies: pytorch#128868
XuehaiPan authored and pytorchmergebot committed Jun 18, 2024
1 parent cec3105 commit 3b798df
Showing 41 changed files with 300 additions and 213 deletions.
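
The reformatting below is mechanical: UFMT chains usort (import sorting) with black (code formatting). As a rough sketch of the conventions that recur throughout this diff (the module and the helper _apply_weight_decay are hypothetical, not code taken from this PR), a file formatted this way ends up looking like:

# Illustrative sketch only: the layout the UFMT pass (usort + black) leaves behind.
from typing import (  # black keeps one name per line because of the trailing comma
    List,
    Optional,
)

import torch
from torch import Tensor


# Two blank lines now separate the import block from the first module-level
# statement, a pattern repeated in every file touched by this diff.
__all__: List[str] = []


def _apply_weight_decay(params: List[Tensor], weight_decay: Optional[float]) -> None:
    # Multi-line call sites keep a trailing comma after the last argument,
    # mirroring the `has_complex=has_complex,` change in functional_adadelta.py.
    if weight_decay is not None:
        torch._foreach_mul_(
            params,
            1.0 - weight_decay,
        )

Once the exclusions below are removed from .lintrunner.toml, running the UFMT linter through lintrunner on these paths should reproduce this style.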
27 changes: 0 additions & 27 deletions .lintrunner.toml
@@ -1413,35 +1413,8 @@ exclude_patterns = [
'torch/distributed/nn/jit/instantiator.py',
'torch/distributed/nn/jit/templates/__init__.py',
'torch/distributed/nn/jit/templates/remote_module_template.py',
'torch/distributed/optim/__init__.py',
'torch/distributed/optim/apply_optimizer_in_backward.py',
'torch/distributed/optim/functional_adadelta.py',
'torch/distributed/optim/functional_adagrad.py',
'torch/distributed/optim/functional_adam.py',
'torch/distributed/optim/functional_adamax.py',
'torch/distributed/optim/functional_adamw.py',
'torch/distributed/optim/functional_rmsprop.py',
'torch/distributed/optim/functional_rprop.py',
'torch/distributed/optim/functional_sgd.py',
'torch/distributed/optim/named_optimizer.py',
'torch/distributed/optim/optimizer.py',
'torch/distributed/optim/post_localSGD_optimizer.py',
'torch/distributed/optim/utils.py',
'torch/distributed/optim/zero_redundancy_optimizer.py',
'torch/distributed/remote_device.py',
'torch/distributed/rendezvous.py',
'torch/distributed/rpc/__init__.py',
'torch/distributed/rpc/_testing/__init__.py',
'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py',
'torch/distributed/rpc/_utils.py',
'torch/distributed/rpc/api.py',
'torch/distributed/rpc/backend_registry.py',
'torch/distributed/rpc/constants.py',
'torch/distributed/rpc/functions.py',
'torch/distributed/rpc/internal.py',
'torch/distributed/rpc/options.py',
'torch/distributed/rpc/rref_proxy.py',
'torch/distributed/rpc/server_process_global_profiler.py',
'torch/distributed/run.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
1 change: 1 addition & 0 deletions torch/distributed/fsdp/__init__.py
@@ -18,6 +18,7 @@
StateDictType,
)


__all__ = [
"BackwardPrefetch",
"CPUOffload",
2 changes: 2 additions & 0 deletions torch/distributed/fsdp/_common_utils.py
@@ -44,9 +44,11 @@
StateDictType,
)


if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions

from ._flat_param import FlatParamHandle

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
1 change: 1 addition & 0 deletions torch/distributed/fsdp/_debug_utils.py
@@ -15,6 +15,7 @@
clean_tensor_name,
)


logger = logging.getLogger(__name__)


1 change: 1 addition & 0 deletions torch/distributed/fsdp/_flat_param.py
@@ -50,6 +50,7 @@
FSDPExtensions,
)


__all__ = [
"FlatParameter",
"FlatParamHandle",
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/_init_utils.py
@@ -58,9 +58,9 @@
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers

from torch.utils._python_dispatch import is_traceable_wrapper_subclass


if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle

1 change: 1 addition & 0 deletions torch/distributed/fsdp/_optim_utils.py
@@ -55,6 +55,7 @@
)
from torch.utils._pytree import tree_map_only


if TYPE_CHECKING:
from torch.distributed._shard.sharded_tensor import ShardedTensor

1 change: 1 addition & 0 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -39,6 +39,7 @@
)
from torch.utils import _pytree as pytree


logger = logging.getLogger(__name__)

# Do not include "process_group" to enable hybrid shard and MoE cases
3 changes: 0 additions & 3 deletions torch/distributed/fsdp/_state_dict_utils.py
@@ -17,9 +17,7 @@

import torch
import torch.distributed as dist

import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper

import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
@@ -29,7 +27,6 @@
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import _mesh_resources

from torch.distributed.fsdp._common_utils import (
_FSDPState,
_get_module_fsdp_state_if_fully_sharded_module,
1 change: 1 addition & 0 deletions torch/distributed/fsdp/_unshard_param_utils.py
@@ -26,6 +26,7 @@

from ._flat_param import FlatParamHandle


FLAT_PARAM = "_flat_param"


1 change: 0 additions & 1 deletion torch/distributed/fsdp/_wrap_utils.py
@@ -11,7 +11,6 @@
_get_module_fsdp_state,
_override_module_mixed_precision,
)

from torch.distributed.fsdp.wrap import (
_construct_wrap_fn,
_or_policy,
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/api.py
@@ -5,12 +5,12 @@

from dataclasses import dataclass
from enum import auto, Enum

from typing import Optional, Sequence, Type

import torch
from torch.nn.modules.batchnorm import _BatchNorm


__all__ = [
"ShardingStrategy",
"BackwardPrefetch",
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -85,8 +85,8 @@
StateDictType,
)
from torch.distributed.utils import _p_assert
from ._flat_param import FlatParameter, FlatParamHandle

from ._flat_param import FlatParameter, FlatParamHandle
from ._optim_utils import (
_flatten_optim_state_dict,
_get_param_id_to_param_from_optim_input,
1 change: 1 addition & 0 deletions torch/distributed/fsdp/sharded_grad_scaler.py
@@ -8,6 +8,7 @@
from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup


logger = logging.getLogger(__name__)


1 change: 1 addition & 0 deletions torch/distributed/fsdp/wrap.py
@@ -24,6 +24,7 @@

import torch.nn as nn


__all__ = [
"always_wrap_policy",
"lambda_auto_wrap_policy",
10 changes: 8 additions & 2 deletions torch/distributed/optim/__init__.py
@@ -15,7 +15,6 @@
_get_in_backward_optimizers,
)
from .functional_adadelta import _FunctionalAdadelta

from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
@@ -26,6 +25,7 @@
from .named_optimizer import _NamedOptimizer
from .utils import as_functional_optim


with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@@ -44,4 +44,10 @@
from .post_localSGD_optimizer import PostLocalSGDOptimizer
from .zero_redundancy_optimizer import ZeroRedundancyOptimizer

__all__ = ["as_functional_optim", "DistributedOptimizer", "PostLocalSGDOptimizer", "ZeroRedundancyOptimizer"]

__all__ = [
"as_functional_optim",
"DistributedOptimizer",
"PostLocalSGDOptimizer",
"ZeroRedundancyOptimizer",
]
10 changes: 6 additions & 4 deletions torch/distributed/optim/apply_optimizer_in_backward.py
@@ -2,6 +2,7 @@

import torch


__all__: List[str] = []

# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter
@@ -11,6 +12,7 @@
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()


@no_type_check
def _apply_optimizer_in_backward(
optimizer_class: Type[torch.optim.Optimizer],
@@ -48,9 +50,7 @@ def _apply_optimizer_in_backward(
# have their registered optimizer(s) applied.
"""
torch._C._log_api_usage_once(
"torch.distributed.optim.apply_optimizer_in_backward"
)
torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward")

@no_type_check
def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
@@ -62,7 +62,9 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
# Don't create a new acc_grad if we already have one
# i.e. for shared parameters or attaching multiple optimizers to a param.
if param not in param_to_acc_grad_map:
param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]
param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[
0
][0]

optimizer = optimizer_class([param], **optimizer_kwargs)

5 changes: 3 additions & 2 deletions torch/distributed/optim/functional_adadelta.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional Adadelta Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
@@ -102,5 +103,5 @@ def step(self, gradients: List[Optional[Tensor]]):
weight_decay=weight_decay,
foreach=self.foreach,
maximize=self.maximize,
has_complex=has_complex
has_complex=has_complex,
)
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_adagrad.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional Adagrad Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_adam.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional Adam Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_adamax.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional Adamax Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_adamw.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional AdamW Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_rmsprop.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional RMSprop Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_rprop.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional Rprop Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
3 changes: 2 additions & 1 deletion torch/distributed/optim/functional_sgd.py
@@ -3,11 +3,12 @@

import torch
import torch.optim._functional as F

from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible Functional SGD Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
13 changes: 11 additions & 2 deletions torch/distributed/optim/named_optimizer.py
@@ -1,9 +1,18 @@
# mypy: allow-untyped-defs
import logging
import warnings

from copy import deepcopy
from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Mapping,
Optional,
overload,
Union,
)

import torch
import torch.nn as nn