diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 16bf63447378f..5436ecb64f0e5 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -1610,8 +1610,8 @@ def test_backend_config(self):
         # Ensure backend config can be created with the following arguments
         backend_config_strings_and_expected_values = [
             (dist.Backend.GLOO, "cpu:gloo,cuda:gloo"),
-            (dist.Backend.NCCL, "cpu:nccl,cuda:nccl"),
-            (dist.Backend.MPI, "cpu:mpi,cuda:mpi"),
+            (dist.Backend.NCCL, "cuda:nccl"),
+            (dist.Backend.MPI, "cpu:mpi"),
             (dist.Backend.UCC, "cpu:ucc,cuda:ucc"),
             (dist.Backend.DUMMY, "cpu:dummy,cuda:dummy"),
             ("DUMMY", "cpu:dummy,cuda:dummy"),
@@ -1620,7 +1620,6 @@ def test_backend_config(self):
             ("cpu:dummy,cuda:nccl", "cpu:dummy,cuda:nccl"),
             ("cpu:gloo,cuda:dummy", "cpu:gloo,cuda:dummy"),
             ("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
-            ("cPu:gLoO,cuDa:NcCl", "cpu:gloo,cuda:nccl")
         ]
 
         for config_str, expected_value in backend_config_strings_and_expected_values:
diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py
index 841a979d3632f..3c0656a671bf3 100644
--- a/test/distributed/test_pg_wrapper.py
+++ b/test/distributed/test_pg_wrapper.py
@@ -126,20 +126,6 @@ def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
                 tensor=tensor,
             )
 
-        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
-            scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
-            scattered_tensor = torch.empty(4)
-            if self.rank == 0:
-                wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
-            else:
-                wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
-        self._validate_error(
-            exception=cm.exception,
-            op_type="SCATTER" if self.rank == 0 else "REDUCE_SCATTER",
-            rank=self.rank,
-            tensor=scattered_tensor,
-        )
-
         with self.assertRaisesRegex(RuntimeError, ".*") as cm:
             if self.rank == 0:
                 wrapper_pg.broadcast(tensor, 0)
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index a5bbffedac129..73e308f450dff 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -548,7 +548,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       c10::DeviceType deviceType,
       BackendType backendType,
       const c10::optional<c10::intrusive_ptr<Backend>>& backend) {
+    // TODO: should we add these entries after the backend setting succeeds?
     deviceTypeToBackendType_[deviceType] = backendType;
+    deviceTypes_.insert(deviceType);
     // if the backendType is already set then reuse it for this device
     if (backendTypeToBackend_.find(backendType) !=
         backendTypeToBackend_.end()) {
@@ -585,6 +587,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     return backendTypeToBackend_.at(backendType);
   }
 
+  // Return device types supported by this ProcessGroup.
+  // Note: the return type is `Device` rather than `DeviceType` for the purpose
+  // of easy comparison at Python level. The `Device` will have default index
+  // (-1).
+  std::vector<c10::Device> getDeviceTypes() const {
+    std::vector<c10::Device> devices;
+    devices.reserve(deviceTypes_.size());
+    for (auto& dt : deviceTypes_) {
+      devices.push_back(c10::Device(dt));
+    }
+    return devices;
+  }
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
@@ -603,6 +618,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   DebugLevel dist_debug_level_;
 
   // Backend classes for this ProcessGroup
+  std::unordered_set<c10::DeviceType> deviceTypes_;
   std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
   std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
       deviceTypeToBackend_;
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 2bdee787b927a..b3bdf088fd569 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1603,6 +1603,8 @@ that adds a prefix to each key inserted to the store.
               py::arg("timeout") = ::c10d::kUnsetTimeout,
               py::arg("wait_all_ranks") = false,
               py::call_guard<py::gil_scoped_release>())
+          .def_property_readonly(
+              "_device_types", &::c10d::ProcessGroup::getDeviceTypes)
           .def(
               "_get_backend_name",
               &::c10d::ProcessGroup::getBackendName,
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index c5309c9047cb7..0d55350adadba 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -204,6 +204,13 @@ class Backend:
 
     backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
 
+    backend_capability: Dict[str, List[str]] = {
+        GLOO : ["cpu", "cuda"],
+        NCCL : ["cuda"],
+        UCC : ["cpu", "cuda"],
+        MPI : ["cpu"],
+    }
+
     def __new__(cls, name: str):
         if not isinstance(name, str):
             raise ValueError(f"Backend name must be a string, but got: {name}")
@@ -214,7 +221,7 @@ def __new__(cls, name: str):
         return value
 
     @classmethod
-    def register_backend(cls, name, func, extended_api=False):
+    def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None):
         """
         Registers a new backend with the given name and instantiating function.
@@ -232,6 +239,9 @@ def register_backend(cls, name, func, extended_api=False):
                                            Default: ``False``. If set to ``True``, the backend
                                            will get an instance of ``c10d::DistributedBackendOptions``, and
                                            a process group options object as defined by the backend implementation.
+            devices (str or list of str, optional): device types this backend
+                            supports, e.g. "cpu", "cuda", etc. If `None`,
+                            both "cpu" and "cuda" are assumed.
 
         .. note:: This support of 3rd party backend is experimental and subject to change.
 
@@ -248,6 +258,23 @@ def register_backend(cls, name, func, extended_api=False):
         setattr(Backend, name.upper(), name.upper())
         Backend.backend_list.append(name.lower())
+
+        # Update device capability matrix in Backend class
+        if devices is None:
+            # This is more of a backward support for groups like `threaded`:
+            # assume default devices "cpu" and "cuda", but warn
+            warnings.warn(
+                f"Device capability of {name} unspecified, assuming `cpu` and "
+                "`cuda`. Please specify it via the `devices` argument of "
+                "`register_backend`."
+            )
+            Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
+        elif isinstance(devices, str):
+            # Single device string specified. Simply convert to list.
+            Backend.backend_capability[name.lower()] = [devices]
+        else:
+            Backend.backend_capability[name.lower()] = devices
+
         Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)
 
 
 class BackendConfig:
@@ -255,21 +282,23 @@ class BackendConfig:
     def __init__(self, backend: Union[str, Backend]):
         self.device_backend_map: Dict[torch.device, Backend] = {}
 
-        # Cases for when backend is a single string (without device types)
         if backend == Backend.UNDEFINED:
             # default config when backend is not specified
+            # supported since PyTorch 2.0
             self.device_backend_map = {
                 "cpu": Backend.GLOO,
                 "cuda": Backend.NCCL,
             }
         elif backend.lower() in Backend.backend_list:
-            # backend applies to all devices (e.g. "NCCL", "GLOO", "UCC", "MPI", "custom_backend")
+            # Cases for when backend is a single string (without device types)
+            # e.g. "nccl", "gloo", "ucc", "mpi"
+            supported_devices = Backend.backend_capability[backend.lower()]
             backend_val = Backend(backend)
             self.device_backend_map = {
-                "cpu": backend_val,
-                "cuda": backend_val,
+                device : backend_val for device in supported_devices
             }
-        else:
+        elif ":" in backend.lower():
+            # Backend specified in "device:backend" format
             # make sure the backend string is in the correct format
             # "{device_type1}:{backend1},{device_type2}:{backend2}"
             # e.g. "cpu:gloo,cuda:nccl"
@@ -288,6 +317,24 @@ def __init__(self, backend: Union[str, Backend]):
                     raise ValueError(f"Duplicate device type {device} \
                                      in backend string: {backend}. {backend_str_error_message}")
                 self.device_backend_map[device] = Backend(backend)
+        else:
+            # User specified a single backend name whose device capability is
+            # unknown, assuming it can support the default devices of PyTorch
+            # (cpu and cuda)
+            warnings.warn(
+                f"Device capability of {backend} unknown, assuming `cpu` and "
+                "`cuda`. You can specify it in `device:backend` format in "
+                "`init_process_group` call."
+            )
+            backend_val = Backend(backend)
+            self.device_backend_map = {
+                "cpu" : backend_val,
+                "cuda" : backend_val,
+            }
+
+        logger.info(
+            f"Using backend config: {self.device_backend_map}"  # noqa: G004
+        )
 
     def __repr__(self):
         # string with all the device:backend pairs separated by commas
@@ -406,6 +453,7 @@ class _World:
     def __init__(self):
         self._default_pg = None
         self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
+        self._pg_object_coll_device: Dict[ProcessGroup, torch.device] = {}
 
     @property
     def default_pg(self):
@@ -491,6 +539,10 @@ def pg_to_tag(self) -> Dict[ProcessGroup, str]:
     def pg_coalesce_state(self) -> Dict[ProcessGroup, List[Union[_CollOp, P2POp]]]:
         return self._pg_coalesce_state
 
+    @property
+    def pg_object_coll_device(self) -> Dict[ProcessGroup, torch.device]:
+        return self._pg_object_coll_device
+
 
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -521,14 +573,54 @@ class GroupMember(metaclass=_WorldMeta):
 
 STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
 
-def _get_pg_device(group: ProcessGroup):
-    """
-    Returns the device to use with ``group``.
-    This is cuda for NCCL and CPU for everything else
-    """
-    if _check_for_nccl_backend(group):
-        return torch.device("cuda", torch.cuda.current_device())
-    return torch.device("cpu")
+def _get_object_coll_device(group: Optional[ProcessGroup] = None):
+    group = group or _get_default_group()
+    if group in _world.pg_object_coll_device:
+        # Previously searched and cached; just return
+        return _world.pg_object_coll_device[group]
+
+    if not isinstance(group, ProcessGroup):
+        # Provide backward compatibility to cases where `group` passed in is
+        # actually a Backend (like `ProcessGroupGloo`) rather than a
+        # `ProcessGroup` in PT 2.0 sense
+        warnings.warn(
+            f"You are using a Backend {type(group)} as a ProcessGroup. "
+            "This usage is deprecated since PyTorch 2.0. Please use a public API "
+            "of PyTorch Distributed instead."
+        )
+        # Most users create Gloo with private API for object collectives
+        _world.pg_object_coll_device[group] = torch.device("cpu")
+        return _world.pg_object_coll_device[group]
+
+    """
+    ``group._device_types`` is a property pybind that returns the devices
+    ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
+    ``group`` supports multiple devices.
+    """
+    devices = group._device_types
+
+    if len(devices) == 1:
+        # User fixed exactly one backend in `init_process_group`
+        _world.pg_object_coll_device[group] = devices[0]
+    elif len(devices) == 0:
+        # No backend has been registered with this PG (maybe because no
+        # collective has been run?) We pick cpu as the default and hopefully
+        # this would lazily init Gloo or other available cpu backend.
+        _world.pg_object_coll_device[group] = torch.device("cpu")
+    elif torch.device("cpu") in devices:
+        # There are multiple backends in this PG and cpu is among them.
+        # cpu is preferred as the object is in cpu memory. No need for device
+        # copy.
+        _world.pg_object_coll_device[group] = torch.device("cpu")
+    else:
+        # No cpu in the backend list. Randomly pick the first backend
+        _world.pg_object_coll_device[group] = devices[0]
+
+    logger.info(
+        f"Using device {_world.pg_object_coll_device[group]} for object "  # noqa: G004
+        "collectives."
+    )
+    return _world.pg_object_coll_device[group]
 
 
 # Environment variable to control whether we do a barrier after process group
@@ -1271,6 +1363,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         _world.pg_to_tag.clear()
         _world.tags_to_pg.clear()
         _world.pg_coalesce_state.clear()
+        _world.pg_object_coll_device.clear()
 
         # when process group doesn't have an explicit name (only WORLD (default)
        # process group can have an explicit name), we use global _world.group_count
@@ -1286,6 +1379,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         del _world.pg_names[pg]
         del _world.pg_group_ranks[pg]
         del _world.pg_backend_config[pg]
+        if pg in _world.pg_object_coll_device:
+            del _world.pg_object_coll_device[pg]
         if pg in _world.pg_coalesce_state.keys():
             warnings.warn(
                 "Some coalesced collectives haven't been launched when "
@@ -2236,7 +2331,7 @@ def all_gather_object(object_list, obj, group=None):
         _warn_not_in_group("all_gather_object")
         return
 
-    current_device = _get_pg_device(group)
+    current_device = _get_object_coll_device(group)
     input_tensor, local_size = _object_to_tensor(obj, current_device)
 
     # Gather all local sizes. This is so that we can find the max size, and index
@@ -2337,7 +2432,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
     # Ensure object_gather_list is specified appropriately.
     my_rank = get_rank()
     _validate_output_list_for_rank(my_rank, dst, object_gather_list)
-    current_device = _get_pg_device(group)
+    current_device = _get_object_coll_device(group)
     input_tensor, local_size = _object_to_tensor(obj, current_device)
 
     # Gather all local sizes. This is so that we can find the max size, and index
@@ -2451,7 +2546,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
-    current_device = device or _get_pg_device(group)
+    current_device = device or _get_object_coll_device(group)
     my_rank = get_rank()
     # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
@@ -2556,7 +2651,7 @@ def scatter_object_list(
     )
 
     my_rank = get_rank()
-    pg_device = _get_pg_device(group)
+    pg_device = _get_object_coll_device(group)
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
             *[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 52b0f08862035..42e5de575b8d4 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1488,20 +1488,6 @@ def test_batch_isend_irecv_gloo_tags(self):
 
         self._barrier()
 
-    # NCCL Batch SEND RECV Tensor Error
-    @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
-    @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
-    def test_batch_isend_irecv_tensor_err(self):
-        self._barrier()
-        rank = dist.get_rank()
-        if rank == 0:
-            with self.assertRaisesRegex(
-                RuntimeError, "Tensors must be CUDA and dense"
-            ):
-                send_tensor = _build_tensor(rank + 1)
-                send_op = dist.P2POp(dist.isend, send_tensor, 1)
-                dist.batch_isend_irecv([send_op])
-
     # NCCL Batch SEND RECV Op Error
     @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
     @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py
index abe1dcf4d6e2c..7936e5efad495 100644
--- a/torch/testing/_internal/distributed/multi_threaded_pg.py
+++ b/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -381,6 +381,7 @@ class WorldData:
     tags_to_pg: Dict[str, List[dist.ProcessGroup]]
     pg_to_tag: Dict[dist.ProcessGroup, str]
     pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
+    pg_object_coll_device: Dict[dist.ProcessGroup, torch.device]
 
 
 class ThreadLocalWorld:
@@ -388,7 +389,7 @@ class ThreadLocalWorld:
 
     def _get_world(self) -> WorldData:
         if not hasattr(ThreadLocalWorld._world, "world"):
-            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
+            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
         return ThreadLocalWorld._world.world
 
     @property
@@ -435,6 +436,10 @@ def pg_to_tag(self):
     def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
        return self._get_world().pg_coalesce_state
 
+    @property
+    def pg_object_coll_device(self) -> Dict[dist.ProcessGroup, torch.device]:
+        return self._get_world().pg_object_coll_device
+
 
 _old_pg_world = None
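
Usage sketch (not part of the patch): the snippet below assumes a hypothetical third-party backend named "dummy" with a placeholder creation function, and exercises only the new `devices` argument of `Backend.register_backend` plus the `Backend.backend_capability` table added above, so no process group needs to be initialized.

# Hedged sketch of the new `devices` argument and capability table.
# "dummy" and `_create_dummy_pg` are hypothetical placeholders, not PyTorch APIs.
import torch.distributed as dist

def _create_dummy_pg(store, rank, world_size, timeout):
    # A real extension would construct and return its backend object here;
    # the function is never invoked in this sketch.
    raise NotImplementedError

# Declare that the hypothetical backend supports CPU only. Omitting `devices`
# falls back to ["cpu", "cuda"] and emits a warning (see register_backend above).
dist.Backend.register_backend("dummy", _create_dummy_pg, devices="cpu")

# The capability table now drives BackendConfig, so a bare backend name expands
# only to the device types it declared.
print(dist.Backend.backend_capability["dummy"])  # ['cpu']
print(str(dist.BackendConfig("dummy")))          # "cpu:dummy"

Declaring the capability up front is what lets `BackendConfig` expand a bare backend name to only the devices that backend actually supports, which is why the test expectations above change from "cpu:nccl,cuda:nccl" to "cuda:nccl" and from "cpu:mpi,cuda:mpi" to "cpu:mpi".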