diff --git a/.lintrunner.toml b/.lintrunner.toml index 149861875b0b1..d524a4b6937be 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,6 +68,7 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', + 'distributed/c10d/*DMAConnectivity.*', 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', diff --git a/build_variables.bzl b/build_variables.bzl index 31ee8754946c6..d0f426de61187 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -488,6 +488,7 @@ libtorch_core_sources = sorted( libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Backend.cpp", "torch/csrc/distributed/c10d/Backoff.cpp", + "torch/csrc/distributed/c10d/DMAConnectivity.cpp", "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", "torch/csrc/distributed/c10d/Functional.cpp", @@ -677,6 +678,7 @@ libtorch_cuda_distributed_base_sources = [ # These files are only supported on Linux (and others) but not on Windows. libtorch_cuda_distributed_extra_sources = [ + "torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp", "torch/csrc/distributed/c10d/NCCLUtils.cpp", "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 34d67770c526f..3738af188ca6b 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,6 +560,7 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index 3e01515bdc9ec..bdd0da8a1d1fb 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -92,6 +92,19 @@ def _verify_symmetric_memory(self, symm_mem): symm_mem.barrier() + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_cuda_nvlink_connectivity_detection(self) -> None: + from torch._C._autograd import DeviceType + from torch._C._distributed_c10d import _detect_dma_connectivity + + connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink") + self.assertEqual(connectivity.device_type, DeviceType.CUDA) + self.assertEqual(connectivity.connection_type, "nvlink") + self.assertEqual(len(connectivity.matrix), torch.cuda.device_count()) + for row in connectivity.matrix: + self.assertEqual(len(row), torch.cuda.device_count()) + @skipIfRocm @skip_if_lt_x_gpu(2) def test_empty_strided_p2p(self) -> None: diff --git a/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp b/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp new file mode 100644 index 0000000000000..afb39bdff92e8 --- /dev/null +++ b/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp @@ -0,0 +1,120 @@ +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include + +#include +#include + +#include +#include + +namespace { + +constexpr int max_nvlinks = 64; + +std::string get_bus_id(int device_idx) { + char bus_id[80]; + cudaDeviceProp prop{}; + C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx)); + snprintf( + bus_id, + sizeof(bus_id), + NVML_DEVICE_PCI_BUS_ID_FMT, + prop.pciDomainID, + prop.pciBusID, + prop.pciDeviceID); + return std::string(bus_id); +} + +struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { + c10::intrusive_ptr detect() override { + int num_devices; + C10_CUDA_CHECK(cudaGetDeviceCount(&num_devices)); + + std::vector> matrix; + matrix.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + matrix.emplace_back(num_devices, 0); + } + + // Obtain the bus_id for all visible devices + std::unordered_map bus_id_to_device_idx; + std::vector bus_ids; + bus_ids.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + auto bus_id = get_bus_id(i); + bus_id_to_device_idx.emplace(bus_id, i); + bus_ids.push_back(std::move(bus_id)); + } + + // Obtain the nvml device for all bus_ids + auto driver_api = c10::cuda::DriverAPI::get(); + std::vector nvml_devices(num_devices, nullptr); + for (int i = 0; i < num_devices; ++i) { + TORCH_CHECK_EQ( + driver_api->nvmlDeviceGetHandleByPciBusId_v2_( + bus_ids[i].c_str(), &nvml_devices[i]), + NVML_SUCCESS); + } + + std::vector switch_link_count(num_devices, 0); + for (int i = 0; i < num_devices; ++i) { + for (int link = 0; link < max_nvlinks; ++link) { + nvmlReturn_t ret; + nvmlIntNvLinkDeviceType_t deviceType; + ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_( + nvml_devices[i], link, &deviceType); + if (ret != NVML_SUCCESS) { + // We've exhausted the NVLinks connected to this device. This error + // is benign. There doesn't seem to be a reliable way to obtain the + // maximum link value that can be passed to the API. Therefore, we + // simply increment the link value until the API fails or we reach a + // predefined maximum value. + break; + } + // Remote device is GPU + if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) { + nvmlPciInfo_t pciInfo; + TORCH_CHECK_EQ( + driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_( + nvml_devices[i], link, &pciInfo), + NVML_SUCCESS); + auto it = bus_id_to_device_idx.find(pciInfo.busId); + if (it != bus_id_to_device_idx.end()) { + if (i != it->second) { + matrix[i][it->second] += 1; + } + } + // Remote device is NVSwitch + } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) { + switch_link_count[i] += 1; + } + } + } + + // Process NVSwitch connections. + // For simplicity, we assume that all NVSwitches are interconnected. + for (int i = 0; i < num_devices; ++i) { + for (int j = 0; j < num_devices; ++j) { + if (i == j) { + continue; + } + matrix[i][j] += std::min(switch_link_count[i], switch_link_count[j]); + } + } + + return c10::make_intrusive( + c10::DeviceType::CUDA, "nvlink", std::move(matrix)); + } +}; + +struct RegisterDetector { + RegisterDetector() { + register_dma_connectivity_detector( + c10::DeviceType::CUDA, "nvlink", c10::make_intrusive()); + } +}; + +static RegisterDetector register_detector_; + +} // namespace +#endif diff --git a/torch/csrc/distributed/c10d/DMAConnectivity.cpp b/torch/csrc/distributed/c10d/DMAConnectivity.cpp new file mode 100644 index 0000000000000..d920eb567197f --- /dev/null +++ b/torch/csrc/distributed/c10d/DMAConnectivity.cpp @@ -0,0 +1,93 @@ +#include + +namespace { + +std::string get_detector_key( + c10::DeviceType device_type, + std::string connection_type) { + std::ostringstream oss; + oss << device_type << "/" << connection_type; + return oss.str(); +} + +class DetectorMap { + public: + static DetectorMap& get() { + static DetectorMap instance; + return instance; + } + + void register_detector( + c10::DeviceType device_type, + const std::string& connection_type, + c10::intrusive_ptr detector) { + auto key = get_detector_key(device_type, connection_type); + detector_map_[key] = std::move(detector); + } + + c10::intrusive_ptr detect( + c10::DeviceType device_type, + const std::string& connection_type) { + auto key = get_detector_key(device_type, connection_type); + { + auto it = cached_.find(key); + if (it != cached_.end()) { + return it->second; + } + } + + auto it = detector_map_.find(key); + TORCH_CHECK( + it != detector_map_.end(), + "DMA connectivity detector for ", + device_type, + " over ", + connection_type, + " is not available"); + auto detector = it->second; + auto connectivity = detector->detect(); + cached_[key] = connectivity; + return connectivity; + } + + private: + DetectorMap() = default; + DetectorMap(const DetectorMap&) = delete; + DetectorMap& operator=(const DetectorMap&) = delete; + + std::unordered_map< + std::string, + c10::intrusive_ptr> + detector_map_; + + std::unordered_map> + cached_; +}; + +}; // namespace + +namespace c10d { + +DMAConnectivity::DMAConnectivity( + c10::DeviceType device_type, + std::string connection_type, + std::vector> matrix) + : device_type(device_type), + connection_type(connection_type), + matrix(std::move(matrix)) {} + +void register_dma_connectivity_detector( + c10::DeviceType device_type, + const std::string& connection_type, + c10::intrusive_ptr detector) { + return DetectorMap::get().register_detector( + device_type, connection_type, std::move(detector)); +} + +c10::intrusive_ptr detect_dma_connectivity( + c10::DeviceType device_type, + const std::string& connection_type) { + return DetectorMap::get().detect(device_type, connection_type); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/DMAConnectivity.hpp b/torch/csrc/distributed/c10d/DMAConnectivity.hpp new file mode 100644 index 0000000000000..cede6aa265c77 --- /dev/null +++ b/torch/csrc/distributed/c10d/DMAConnectivity.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include + +namespace c10d { + +struct TORCH_API DMAConnectivity : c10::intrusive_ptr_target { + c10::DeviceType device_type; + std::string connection_type; + + // This is an NxN matrix representing the connectivity between N devices, + // where each element matrix[i][j] indicates the connectivity between device + // i and device j. A value of 0 denotes that there is no connection between + // device i and j. The meaning of non-zero values are specific to the + // connection type (e.g., for NVLink it represents the number of NVLinks). + std::vector> matrix; + + explicit DMAConnectivity( + c10::DeviceType device_type, + std::string connection_type, + std::vector> matrix); +}; + +struct DMAConnectivityDetector : c10::intrusive_ptr_target { + virtual c10::intrusive_ptr detect() = 0; + virtual ~DMAConnectivityDetector() {} +}; + +C10_EXPORT void register_dma_connectivity_detector( + c10::DeviceType device_type, + const std::string& connection_type, + c10::intrusive_ptr detector); + +TORCH_API c10::intrusive_ptr detect_dma_connectivity( + c10::DeviceType device_type, + const std::string& connection_type); + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 505a30a3dc237..a2cd6b420dcf5 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -39,6 +39,7 @@ #include #include +#include #include #include @@ -975,6 +976,16 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); + py::class_< + ::c10d::DMAConnectivity, + c10::intrusive_ptr<::c10d::DMAConnectivity>>(module, "_DMAConnectivity") + .def_readonly("device_type", &::c10d::DMAConnectivity::device_type) + .def_readonly( + "connection_type", &::c10d::DMAConnectivity::connection_type) + .def_readonly("matrix", &::c10d::DMAConnectivity::matrix); + + module.def("_detect_dma_connectivity", ::c10d::detect_dma_connectivity); + using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; py::class_>( module, "_SymmetricMemory")