From 3b798df853444d66077ffa846f5682e621b07388 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:44 +0800 Subject: [PATCH] [BE][Easy] enable UFMT for `torch/distributed/{fsdp,optim,rpc}/` (#128869) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869 Approved by: https://github.com/fegin ghstack dependencies: #128868 --- .lintrunner.toml | 27 ---- torch/distributed/fsdp/__init__.py | 1 + torch/distributed/fsdp/_common_utils.py | 2 + torch/distributed/fsdp/_debug_utils.py | 1 + torch/distributed/fsdp/_flat_param.py | 1 + torch/distributed/fsdp/_init_utils.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 1 + torch/distributed/fsdp/_runtime_utils.py | 1 + torch/distributed/fsdp/_state_dict_utils.py | 3 - .../distributed/fsdp/_unshard_param_utils.py | 1 + torch/distributed/fsdp/_wrap_utils.py | 1 - torch/distributed/fsdp/api.py | 2 +- .../fsdp/fully_sharded_data_parallel.py | 2 +- torch/distributed/fsdp/sharded_grad_scaler.py | 1 + torch/distributed/fsdp/wrap.py | 1 + torch/distributed/optim/__init__.py | 10 +- .../optim/apply_optimizer_in_backward.py | 10 +- .../distributed/optim/functional_adadelta.py | 5 +- torch/distributed/optim/functional_adagrad.py | 3 +- torch/distributed/optim/functional_adam.py | 3 +- torch/distributed/optim/functional_adamax.py | 3 +- torch/distributed/optim/functional_adamw.py | 3 +- torch/distributed/optim/functional_rmsprop.py | 3 +- torch/distributed/optim/functional_rprop.py | 3 +- torch/distributed/optim/functional_sgd.py | 3 +- torch/distributed/optim/named_optimizer.py | 13 +- torch/distributed/optim/optimizer.py | 5 +- torch/distributed/optim/utils.py | 2 + .../optim/zero_redundancy_optimizer.py | 37 +++--- torch/distributed/rpc/__init__.py | 69 +++++----- torch/distributed/rpc/_testing/__init__.py | 5 +- .../_testing/faulty_agent_backend_registry.py | 11 +- torch/distributed/rpc/_utils.py | 19 ++- torch/distributed/rpc/api.py | 118 ++++++++++-------- torch/distributed/rpc/backend_registry.py | 99 ++++++++++----- torch/distributed/rpc/constants.py | 3 +- torch/distributed/rpc/functions.py | 2 + torch/distributed/rpc/internal.py | 5 +- torch/distributed/rpc/options.py | 2 + torch/distributed/rpc/rref_proxy.py | 17 ++- .../rpc/server_process_global_profiler.py | 13 +- 41 files changed, 300 insertions(+), 213 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index e3f1b58027c3ec..99c04cac4fbb39 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1413,35 +1413,8 @@ exclude_patterns = [ 'torch/distributed/nn/jit/instantiator.py', 'torch/distributed/nn/jit/templates/__init__.py', 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/optim/__init__.py', - 'torch/distributed/optim/apply_optimizer_in_backward.py', - 'torch/distributed/optim/functional_adadelta.py', - 'torch/distributed/optim/functional_adagrad.py', - 'torch/distributed/optim/functional_adam.py', - 'torch/distributed/optim/functional_adamax.py', - 'torch/distributed/optim/functional_adamw.py', - 'torch/distributed/optim/functional_rmsprop.py', - 'torch/distributed/optim/functional_rprop.py', - 'torch/distributed/optim/functional_sgd.py', - 'torch/distributed/optim/named_optimizer.py', - 'torch/distributed/optim/optimizer.py', - 'torch/distributed/optim/post_localSGD_optimizer.py', - 'torch/distributed/optim/utils.py', - 'torch/distributed/optim/zero_redundancy_optimizer.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', - 'torch/distributed/rpc/__init__.py', - 'torch/distributed/rpc/_testing/__init__.py', - 'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py', - 'torch/distributed/rpc/_utils.py', - 'torch/distributed/rpc/api.py', - 'torch/distributed/rpc/backend_registry.py', - 'torch/distributed/rpc/constants.py', - 'torch/distributed/rpc/functions.py', - 'torch/distributed/rpc/internal.py', - 'torch/distributed/rpc/options.py', - 'torch/distributed/rpc/rref_proxy.py', - 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index d887730f442f6d..6180dbb3df299e 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -18,6 +18,7 @@ StateDictType, ) + __all__ = [ "BackwardPrefetch", "CPUOffload", diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index aae2405d0bb50d..10d0f821265119 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -44,9 +44,11 @@ StateDictType, ) + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + from ._flat_param import FlatParamHandle FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 523330e5580dfd..163d9a045b68ea 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -15,6 +15,7 @@ clean_tensor_name, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 816b91433063af..8bc975dc72fd5a 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -50,6 +50,7 @@ FSDPExtensions, ) + __all__ = [ "FlatParameter", "FlatParamHandle", diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index c8b58091bf89b5..aaeedf22397a42 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -58,9 +58,9 @@ from torch.distributed.fsdp.wrap import _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions from torch.distributed.utils import _sync_params_and_buffers - from torch.utils._python_dispatch import is_traceable_wrapper_subclass + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 54f800a168653a..4cfe761769a3b9 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -55,6 +55,7 @@ ) from torch.utils._pytree import tree_map_only + if TYPE_CHECKING: from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 833c1d45697aef..f84e7dd3e5055e 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -39,6 +39,7 @@ ) from torch.utils import _pytree as pytree + logger = logging.getLogger(__name__) # Do not include "process_group" to enable hybrid shard and MoE cases diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 797a0116587bb3..815cfb2dd4a1ff 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -17,9 +17,7 @@ import torch import torch.distributed as dist - import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper - import torch.nn as nn import torch.nn.functional as F from torch.distributed._shard.sharded_tensor import ( @@ -29,7 +27,6 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import ( _FSDPState, _get_module_fsdp_state_if_fully_sharded_module, diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 435193a88703a1..4143d2928c8b83 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -26,6 +26,7 @@ from ._flat_param import FlatParamHandle + FLAT_PARAM = "_flat_param" diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 84cdf250d8ae1e..895bcbd8e967b4 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -11,7 +11,6 @@ _get_module_fsdp_state, _override_module_mixed_precision, ) - from torch.distributed.fsdp.wrap import ( _construct_wrap_fn, _or_policy, diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 0272ee0c57c9fc..f2e4bdb7ea0231 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -5,12 +5,12 @@ from dataclasses import dataclass from enum import auto, Enum - from typing import Optional, Sequence, Type import torch from torch.nn.modules.batchnorm import _BatchNorm + __all__ = [ "ShardingStrategy", "BackwardPrefetch", diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 9edd057a8f371e..1567bb973b22a6 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -85,8 +85,8 @@ StateDictType, ) from torch.distributed.utils import _p_assert -from ._flat_param import FlatParameter, FlatParamHandle +from ._flat_param import FlatParameter, FlatParamHandle from ._optim_utils import ( _flatten_optim_state_dict, _get_param_id_to_param_from_optim_input, diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 3487e01263c719..7c1b2f83528683 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -8,6 +8,7 @@ from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState from torch.distributed.distributed_c10d import ProcessGroup + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index acb5a6f1f642ad..f8604bbb1bb048 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -24,6 +24,7 @@ import torch.nn as nn + __all__ = [ "always_wrap_policy", "lambda_auto_wrap_policy", diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index fe33265fd532f4..924b993ec8414b 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -15,7 +15,6 @@ _get_in_backward_optimizers, ) from .functional_adadelta import _FunctionalAdadelta - from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam from .functional_adamax import _FunctionalAdamax @@ -26,6 +25,7 @@ from .named_optimizer import _NamedOptimizer from .utils import as_functional_optim + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -44,4 +44,10 @@ from .post_localSGD_optimizer import PostLocalSGDOptimizer from .zero_redundancy_optimizer import ZeroRedundancyOptimizer -__all__ = ["as_functional_optim", "DistributedOptimizer", "PostLocalSGDOptimizer", "ZeroRedundancyOptimizer"] + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/torch/distributed/optim/apply_optimizer_in_backward.py b/torch/distributed/optim/apply_optimizer_in_backward.py index 6bd182cca5736f..36f679f4eba49b 100644 --- a/torch/distributed/optim/apply_optimizer_in_backward.py +++ b/torch/distributed/optim/apply_optimizer_in_backward.py @@ -2,6 +2,7 @@ import torch + __all__: List[str] = [] # WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter @@ -11,6 +12,7 @@ param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + @no_type_check def _apply_optimizer_in_backward( optimizer_class: Type[torch.optim.Optimizer], @@ -48,9 +50,7 @@ def _apply_optimizer_in_backward( # have their registered optimizer(s) applied. """ - torch._C._log_api_usage_once( - "torch.distributed.optim.apply_optimizer_in_backward" - ) + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") @no_type_check def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: @@ -62,7 +62,9 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: # Don't create a new acc_grad if we already have one # i.e. for shared parameters or attaching multiple optimizers to a param. if param not in param_to_acc_grad_map: - param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0] + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] optimizer = optimizer_class([param], **optimizer_kwargs) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index bc5f7c63dd1751..3ad51348b6afab 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adadelta Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, @@ -102,5 +103,5 @@ def step(self, gradients: List[Optional[Tensor]]): weight_decay=weight_decay, foreach=self.foreach, maximize=self.maximize, - has_complex=has_complex + has_complex=has_complex, ) diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index 93a1fe2b2240df..67f7328489ed21 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adagrad Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 34868d23d8a53c..3ed271765170c6 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adam Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index 32bce65dfe1f50..8f1fdc0ccc02be 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adamax Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 43addd0508221f..d3f1f80e9209bd 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional AdamW Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index 851119c8600c0e..7a03e8e9f462f8 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional RMSprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 60742bc68896fc..615015a95a316b 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Rprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 3a8176e877057c..32381855db6b55 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional SGD Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 9e1e5377873d10..8e0b539b148264 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,9 +1,18 @@ # mypy: allow-untyped-defs import logging import warnings - from copy import deepcopy -from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Mapping, + Optional, + overload, + Union, +) import torch import torch.nn as nn diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index f2eca606c02611..65df14770c21c4 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import logging - from collections import defaultdict from threading import Lock from typing import List, Optional @@ -12,8 +11,10 @@ import torch.nn as nn from torch import Tensor from torch.distributed.rpc import RRef + from .utils import functional_optim_map + __all__ = ["DistributedOptimizer"] logger = logging.getLogger(__name__) @@ -205,7 +206,7 @@ def __init__(self, optimizer_class, params_rref, *args, **kwargs): "(i.e. Distributed Model Parallel training on CPU) due to the Python's " "Global Interpreter Lock (GIL). Please file an issue if you need this " "optimizer in TorchScript. ", - optimizer_class + optimizer_class, ) optimizer_new_func = _new_local_optimizer diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index af2220ca557493..d2c75eee7e39bc 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -2,6 +2,7 @@ from typing import Type from torch import optim + from .functional_adadelta import _FunctionalAdadelta from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam @@ -11,6 +12,7 @@ from .functional_rprop import _FunctionalRprop from .functional_sgd import _FunctionalSGD + # dict to map a user passed in optimizer_class to a functional # optimizer class if we have already defined inside the # distributed.optim package, this is so that we hide the diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 8a3be3b0181536..f664d11afb79c0 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -20,11 +20,12 @@ from torch.optim import Optimizer -logger = logging.getLogger(__name__) - __all__ = ["ZeroRedundancyOptimizer"] +logger = logging.getLogger(__name__) + + # Credits: classy_vision/generic/distributed_util.py def _recursive_copy_to_device( value: Any, @@ -925,9 +926,9 @@ def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]: mapping bucket indices to :class:`_DDPBucketAssignment` s for each rank. """ - assert self._overlap_with_ddp, ( - "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" - ) + assert ( + self._overlap_with_ddp + ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" if len(self._bucket_assignments_per_rank_cache) > 0: return self._bucket_assignments_per_rank_cache @@ -1074,9 +1075,9 @@ def _local_step( "Specifying `gradients` should not " "be used when `overlap_with_ddp=False`" ) - assert closure is None, ( - "`closure` is not supported when using a local functional optimizer" - ) + assert ( + closure is None + ), "`closure` is not supported when using a local functional optimizer" loss = self.optim.step(gradients=gradients) # Sync any updated attributes in the local optimizer to the exposed @@ -1504,7 +1505,7 @@ def _init_local_optimizer(self) -> None: "%s does not support the argument " "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " "error due to an empty parameter list", - self._optim_constructor + self._optim_constructor, ) self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] @@ -1515,17 +1516,16 @@ def _init_local_optimizer(self) -> None: self._bucket_assignments_per_rank[self.global_rank] ) logger.info( - "rank %s with %s parameters " - "across %s buckets", - self.global_rank, local_numel, num_assigned_buckets + "rank %s with %s parameters " "across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, ) if self.global_rank == 0: logger.info( - "%s DDP " - "buckets and " - "%s bucket " - "assignments", - len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments + "%s DDP " "buckets and " "%s bucket " "assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, ) else: # NOTE: Passing `param_groups` into the local optimizer constructor @@ -1640,7 +1640,8 @@ def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: "Using the functional optimizer %s " "instead of %s since " "`overlap_with_ddp=True`", - optim_constructor, optimizer_class + optim_constructor, + optimizer_class, ) return optim_constructor else: diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index 581433d220c63e..6c6608a2a773f3 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,22 +1,25 @@ # mypy: allow-untyped-defs -from datetime import timedelta import logging import os import threading import warnings +from datetime import timedelta from typing import Generator, Tuple from urllib.parse import urlparse import torch import torch.distributed as dist + +__all__ = ["is_available"] + + logger = logging.getLogger(__name__) _init_counter = 0 _init_counter_lock = threading.Lock() -__all__ = ["is_available"] def is_available() -> bool: return hasattr(torch._C, "_rpc_init") @@ -27,54 +30,51 @@ def is_available() -> bool: if is_available(): + import numbers + + import torch.distributed.autograd as dist_autograd from torch._C._distributed_c10d import Store - from torch._C._distributed_rpc import ( + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, _disable_jit_rref_pickle, - _enable_jit_rref_pickle, _disable_server_process_global_profiler, + _enable_jit_rref_pickle, _enable_server_process_global_profiler, - _set_and_start_rpc_agent, - _reset_current_rpc_agent, - _delete_all_user_and_unforked_owner_rrefs, - _destroy_rref_context, - _set_profiler_node_id, - _is_current_rpc_agent_set, - _rref_context_get_debug_info, - _cleanup_python_rpc_handler, - _invoke_rpc_builtin, - _invoke_rpc_python_udf, - _invoke_rpc_torchscript, + _get_current_rpc_agent, _invoke_remote_builtin, _invoke_remote_python_udf, _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, _set_rpc_timeout, - _get_current_rpc_agent, - get_rpc_timeout, - enable_gil_profiling, - RpcBackendOptions, _TensorPipeRpcBackendOptionsBase, - RpcAgent, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, PyRRef, - TensorPipeAgent, RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + TensorPipeAgent, WorkerInfo, - _DEFAULT_INIT_METHOD, - _DEFAULT_NUM_WORKER_THREADS, - _UNSET_RPC_TIMEOUT, - _DEFAULT_RPC_TIMEOUT_SEC, - ) # noqa: F401 + ) from . import api, backend_registry, functions from .api import * # noqa: F401,F403 - import numbers - - import torch.distributed.autograd as dist_autograd - from .backend_registry import BackendType from .options import TensorPipeRpcBackendOptions # noqa: F401 - from .server_process_global_profiler import ( - _server_process_global_profile, - ) + from .server_process_global_profiler import _server_process_global_profile rendezvous_iterator: Generator[Tuple[Store, int, int], None, None] @@ -153,7 +153,7 @@ def init_rpc( "corresponding to %(backend)s, hence that backend will be used " "instead of the default BackendType.TENSORPIPE. To silence this " "warning pass `backend=%(backend)s` explicitly.", - {'backend': backend} + {"backend": backend}, ) if backend is None: @@ -224,7 +224,6 @@ def _init_rpc_backend( world_size=None, rpc_backend_options=None, ): - _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) if _is_current_rpc_agent_set(): diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 640c4d09f06281..8ac1c02f4cee4c 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -12,8 +12,9 @@ def is_available(): if is_available(): # Registers FAULTY_TENSORPIPE RPC backend. - from . import faulty_agent_backend_registry from torch._C._distributed_rpc_testing import ( - FaultyTensorPipeRpcBackendOptions, FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, ) + + from . import faulty_agent_backend_registry diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py index 9e8660989e5a7c..d04882e16e79a9 100644 --- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py +++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -4,6 +4,7 @@ import torch.distributed as dist import torch.distributed.rpc as rpc + def _faulty_tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, @@ -11,7 +12,7 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( messages_to_fail, messages_to_delay, num_fail_sends, - **kwargs + **kwargs, ): from . import FaultyTensorPipeRpcBackendOptions @@ -28,16 +29,14 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( def _faulty_tensorpipe_init_backend_handler( store, name, rank, world_size, rpc_backend_options ): - from . import FaultyTensorPipeAgent - from . import FaultyTensorPipeRpcBackendOptions from torch.distributed.rpc import api + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, FaultyTensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" ) diff --git a/torch/distributed/rpc/_utils.py b/torch/distributed/rpc/_utils.py index 6499a80e0e1724..8925bc662b5f97 100644 --- a/torch/distributed/rpc/_utils.py +++ b/torch/distributed/rpc/_utils.py @@ -1,12 +1,14 @@ # mypy: allow-untyped-defs +import logging from contextlib import contextmanager from typing import cast -import logging -from . import api -from . import TensorPipeAgent + +from . import api, TensorPipeAgent + logger = logging.getLogger(__name__) + @contextmanager def _group_membership_management(store, name, is_join): token_key = "RpcGroupManagementToken" @@ -29,10 +31,17 @@ def _group_membership_management(store, name, is_join): try: store.wait([returned]) except RuntimeError: - logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned) + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) raise + def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) - ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) return ret diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index a33358eb0dc674..5fc9e61aa5592a 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -1,6 +1,4 @@ # mypy: allow-untyped-defs -__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", - "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] import collections import contextlib @@ -8,17 +6,10 @@ import inspect import logging import threading -from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING +from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar import torch -from torch.futures import Future - from torch._C._distributed_rpc import ( - PyRRef, - RemoteProfilerManager, - WorkerInfo, - TensorPipeAgent, - get_rpc_timeout, _cleanup_python_rpc_handler, _delete_all_user_and_unforked_owner_rrefs, _destroy_rref_context, @@ -32,18 +23,36 @@ _is_current_rpc_agent_set, _reset_current_rpc_agent, _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + TensorPipeAgent, + WorkerInfo, ) +from torch.futures import Future +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, PythonUDF, RPCExecMode, - _internal_rpc_pickler, - _build_rpc_profiling_key, ) -from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT -from ._utils import _group_membership_management, _update_group_membership +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + logger = logging.getLogger(__name__) @@ -59,6 +68,7 @@ _ignore_rref_leak = True _default_pickler = _internal_rpc_pickler + @contextlib.contextmanager def _use_rpc_pickler(rpc_pickler): r""" @@ -107,7 +117,9 @@ def __init__(self): _ALL_WORKER_NAMES: Set[Any] = set() _all_gather_dict_lock = threading.RLock() _all_gather_sequence_id: Dict[str, int] = {} -_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates) +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) def _init_rpc_states(agent): @@ -146,6 +158,7 @@ def _broadcast_to_followers(sequence_id, objects_map): states.gathered_objects = objects_map states.proceed_signal.set() + _thread_local_var = threading.local() @@ -245,7 +258,7 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): follower_name, _broadcast_to_followers, args=(sequence_id, states.gathered_objects), - timeout=rpc_timeout + timeout=rpc_timeout, ) worker_name_to_response_future_dict[follower_name] = fut @@ -283,9 +296,7 @@ def _barrier(worker_names): try: _all_gather(None, set(worker_names)) except RuntimeError as ex: - logger.error( - "Failed to complete barrier, got error %s", ex - ) + logger.error("Failed to complete barrier, got error %s", ex) @_require_initialized @@ -371,7 +382,11 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): all_worker_infos = agent.get_worker_infos() for worker in all_worker_infos: if worker.name != my_name: - rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False)) + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) agent.join(shutdown=True, timeout=timeout) finally: # In case of errors, continue to complete the local shutdown. @@ -445,13 +460,10 @@ def _rref_typeof_on_owner(rref, blocking: bool = True): return future -def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True): - fut = rpc_async( - rref.owner(), - _rref_typeof_on_owner, - args=(rref,), - timeout=timeout - ) +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) if blocking: return fut.wait() else: @@ -463,13 +475,16 @@ def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: boo if TYPE_CHECKING: + class RRef(PyRRef[T], Generic[T]): pass + else: try: # Combine the implementation class and the type class. class RRef(PyRRef, Generic[T]): pass + except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases @@ -517,7 +532,9 @@ def method(self, *args, **kwargs): assert docstring is not None, "RRef user-facing methods should all have docstrings." # Do surgery on pybind11 generated docstrings. - docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef") + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) # Attach user-facing RRef method with modified docstring. new_method = method_factory(method_name, docstring) @@ -633,7 +650,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): dst_worker_info = _to_worker_info(to) should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -647,7 +666,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): func = wrapped if qualified_name is not None: - rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs) + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) elif isinstance(func, torch.jit.ScriptFunction): rref = _invoke_remote_torchscript( dst_worker_info.name, @@ -662,11 +683,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): PythonUDF(func, args, kwargs) ) rref = _invoke_remote_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec ) # attach profiling information if should_profile: @@ -678,7 +695,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): return rref -def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT): +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): if not callable(func): raise TypeError("function should be callable.") @@ -687,7 +706,9 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -702,11 +723,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = if qualified_name is not None: fut = _invoke_rpc_builtin( - dst_worker_info, - qualified_name, - rpc_timeout, - *args, - **kwargs + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs ) elif isinstance(func, torch.jit.ScriptFunction): fut = _invoke_rpc_torchscript( @@ -715,18 +732,14 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = args, kwargs, rpc_timeout, - is_async_exec + is_async_exec, ) else: (pickled_python_udf, tensors) = _default_pickler.serialize( PythonUDF(func, args, kwargs) ) fut = _invoke_rpc_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - rpc_timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec ) if should_profile: assert torch.autograd._profiler_enabled() @@ -915,12 +928,15 @@ def _get_should_profile(): # Kineto profiler. ActiveProfilerType = torch._C._profiler.ActiveProfilerType return ( - torch.autograd._profiler_enabled() and - torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() + == ActiveProfilerType.LEGACY # type: ignore[attr-defined] ) -def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info): +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): ctx_manager = contextlib.nullcontext() if should_profile: diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 6290f9e8e2054b..a06f0276ede95a 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] + import collections import enum @@ -7,13 +7,19 @@ import torch import torch.distributed as dist + +from . import api, constants as rpc_constants from ._utils import _group_membership_management, _update_group_membership -from . import api -from . import constants as rpc_constants -__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend", - "BackendValue", "BackendType"] +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] BackendValue = collections.namedtuple( "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] @@ -41,6 +47,7 @@ def _backend_type_repr(self): if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc + def backend_registered(backend_name): """ Checks if backend_name is registered as an RPC backend. @@ -80,7 +87,7 @@ def register_backend( init_backend_handler=init_backend_handler, ) }, - **existing_enum_dict + **existing_enum_dict, ) # Can't handle Function Enum API (mypy bug #9079) BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] @@ -90,20 +97,22 @@ def register_backend( BackendType.__doc__ = _backend_type_doc return BackendType[backend_name] + def construct_rpc_backend_options( backend, rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, init_method=rpc_constants.DEFAULT_INIT_METHOD, - **kwargs + **kwargs, ): - return backend.value.construct_rpc_backend_options_handler( rpc_timeout, init_method, **kwargs ) + def init_backend(backend, *args, **kwargs): return backend.value.init_backend_handler(*args, **kwargs) + def _init_process_group(store, rank, world_size): # Initialize ProcessGroup. process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT @@ -115,22 +124,21 @@ def _init_process_group(store, rank, world_size): assert group is not None, "Failed to initialize default ProcessGroup." if (rank != -1) and (rank != group.rank()): - raise RuntimeError( - f"rank argument {rank} doesn't match pg rank {group.rank()}" - ) + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") if (world_size != -1) and (world_size != group.size()): raise RuntimeError( f"world_size argument {world_size} doesn't match pg size {group.size()}" ) return group + def _tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, _transports=None, _channels=None, - **kwargs + **kwargs, ): from . import TensorPipeRpcBackendOptions @@ -155,9 +163,9 @@ def _tensorpipe_validate_devices(devices, device_count): def _tensorpipe_exchange_and_check_all_device_maps( my_name, my_device_count, my_device_maps, my_devices, group ): - gathered: List[Tuple[ - str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device] - ]] = [("", 0, {}, []) for _ in range(group.size())] + gathered: List[ + Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] dist.all_gather_object( gathered, (my_name, my_device_count, my_device_maps, my_devices), group ) @@ -173,13 +181,15 @@ def _tensorpipe_exchange_and_check_all_device_maps( my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) return reverse_device_maps, my_devices -def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True): + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): for node in all_names: devices = all_devices[node] if len(set(devices)) != len(devices): raise ValueError( - f"Node {node} has duplicated devices\n" - f"devices = {devices}" + f"Node {node} has duplicated devices\n" f"devices = {devices}" ) if not _tensorpipe_validate_devices(devices, all_device_counts[node]): raise ValueError( @@ -190,7 +200,9 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev for source_node in all_names: # For dynamic group (non-static) do not check the target node name since it may not have joined yet - if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names): + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): raise ValueError( f"Node {source_node} has invalid target node names in its device maps\n" f"device maps = {all_device_maps[source_node].keys()}\n" @@ -238,6 +250,7 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev f"device count = {all_device_counts[target_node]}" ) + def _create_device_list(my_devices, my_device_maps, reverse_device_maps): if not my_devices: devices_set: Set[torch.device] = set() @@ -250,6 +263,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps): my_devices = sorted(my_devices, key=lambda d: d.index) return my_devices + def _create_reverse_mapping(my_name, all_names, all_device_maps): reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {} for node in all_names: @@ -259,8 +273,10 @@ def _create_reverse_mapping(my_name, all_names, all_device_maps): } return reverse_device_maps + def _get_device_infos(): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) opts = agent._get_backend_options() device_count = torch.cuda.device_count() @@ -268,8 +284,10 @@ def _get_device_infos(): torch.cuda.init() return device_count, opts.device_maps, opts.devices + def _set_devices_and_reverse_device_map(agent): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, agent) # Group state is retrieved from local agent # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid @@ -282,34 +300,52 @@ def _set_devices_and_reverse_device_map(agent): worker_name = worker_info.name if worker_name != my_name: # TODO: make async? - device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos) + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) else: opts = agent._get_backend_options() - device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) all_device_counts[worker_name] = device_count all_device_maps[worker_name] = device_map all_devices[worker_name] = devices all_names.append(worker_name) - _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False) + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps for worker_name in all_names: # Set device list for each worker - all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps) - api.rpc_sync(worker_name, _update_group_membership, - args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True)) + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions -def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options): - from . import TensorPipeAgent - from . import TensorPipeRpcBackendOptions if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, TensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" ) @@ -389,6 +425,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ raise return agent + register_backend( "TENSORPIPE", _tensorpipe_construct_rpc_backend_options_handler, diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py index 3bc525b70d9bb1..56f6db4db259df 100644 --- a/torch/distributed/rpc/constants.py +++ b/torch/distributed/rpc/constants.py @@ -1,5 +1,6 @@ from datetime import timedelta from typing import List + from torch._C._distributed_rpc import ( _DEFAULT_INIT_METHOD, _DEFAULT_NUM_WORKER_THREADS, @@ -17,7 +18,7 @@ DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS # Ensure that we don't time out when there are long periods of time without # any operations against the underlying ProcessGroup. -DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1) +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) # Value indicating that timeout is not set for RPC call, and the default should be used. UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index c9e92980cf5662..e48ea8cc534ab8 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -159,9 +159,11 @@ def async_execution(fn): >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() >>> print(ret) # prints tensor([4., 4.]) """ + @functools.wraps(fn) def wrapper(*args, **kwargs): return fn(*args, **kwargs) + # Can't declare and use attributes of function objects (mypy#2087) wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] return wrapper diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 2fc647c414d969..5faf7d14d0da57 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -12,6 +12,7 @@ import torch.distributed as dist from torch._C._distributed_rpc import _get_current_rpc_agent + __all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] # Thread local tensor tables to store tensors while pickling torch.Tensor @@ -251,7 +252,9 @@ def _build_rpc_profiling_key( Returns: String representing profiling key """ - profile_key = f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) return profile_key diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 70328f34596958..53bf473ba56287 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -3,6 +3,7 @@ import torch from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase + from . import constants as rpc_contants @@ -10,6 +11,7 @@ __all__ = ["TensorPipeRpcBackendOptions"] + def _to_device(device: DeviceType) -> torch.device: device = torch.device(device) if device.type != "cuda": diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index cdb0a5d22b7423..85927b68bacb9c 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -1,20 +1,22 @@ # mypy: allow-untyped-defs from functools import partial -from . import functions -from . import rpc_async - import torch -from .constants import UNSET_RPC_TIMEOUT from torch.futures import Future +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + def _local_invoke(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + @functions.async_execution def _local_invoke_async_execution(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): def _rref_type_cont(rref_fut): rref_type = rref_fut.value() @@ -33,7 +35,7 @@ def _rref_type_cont(rref_fut): rref.owner(), _invoke_func, args=(rref, func_name, args, kwargs), - timeout=timeout + timeout=timeout, ) rref_fut = rref._get_type(timeout=timeout, blocking=False) @@ -63,6 +65,7 @@ def _complete_op(fut): rref_fut.then(_wrap_rref_type_cont) return result + # This class manages proxied RPC API calls for RRefs. It is entirely used from # C++ (see python_rpc_handler.cpp). class RRefProxy: @@ -72,4 +75,6 @@ def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): self.rpc_timeout = timeout def __getattr__(self, func_name): - return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout) + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 0543ab56a877fb..b5d089d305253f 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -2,18 +2,20 @@ # mypy: allow-untyped-defs import itertools +from typing import List import torch from torch.autograd.profiler_legacy import profile -from typing import List from . import ( _disable_server_process_global_profiler, _enable_server_process_global_profiler, ) + __all__: List[str] = [] + class _server_process_global_profile(profile): """ It has the same API as ``torch.autograd.profiler.profile`` class, @@ -123,7 +125,8 @@ def __enter__(self): False, False, False, - torch.profiler._ExperimentalConfig()) + torch.profiler._ExperimentalConfig(), + ) _enable_server_process_global_profiler(profiler_config) return self @@ -152,8 +155,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): process_global_function_events = [] for thread_local_events in process_global_events: # Parse from ``Event``s to ``FunctionEvent``s. - thread_local_function_events = torch.autograd.profiler_legacy._parse_legacy_records( - thread_local_events + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) ) thread_local_function_events.sort( key=lambda function_event: [