forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[c10d] Introduce a util for detecting DMA connectivity among devices (p…
…ytorch#129510) This PR introduces `_detect_dma_connectivity` - a utility for detecting DMA connectivity among devices. The "DMA connectivity" in this context is more stringent than the ability to perform memory copy without CPU involvement. We define it as the ability for a device to issue load/store instructions and perform atomic operations on memory that resides on connected devices. The ability translates to the ability to run most aten GPU operations with operands backed by remote memory. `_detect_dma_connectivity` can help PyTorch and its users to determine whether certain DMA-based optimizations are possible. `_detect_dma_connectivity` takes a `(device_type, connection_type)` pair and returns a matrix describing the connectivity. Connectivity detectors are statically registered on a `(device_type, connection_type)` basis. This PR implements the detector for `(CUDA, "nvlink")`. Later, detectors for pairs such as `(ROCM, "infinity_fabric")` can be introduced. Example: ```python3 >>> from torch._C._autograd import DeviceType >>> from torch._C._distributed_c10d import _detect_dma_connectivity >>> connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink") >>> for row in connectivity.matrix: ... print(row) ... [0, 18, 18, 18, 18, 18, 18, 18] [18, 0, 18, 18, 18, 18, 18, 18] [18, 18, 0, 18, 18, 18, 18, 18] [18, 18, 18, 0, 18, 18, 18, 18] [18, 18, 18, 18, 0, 18, 18, 18] [18, 18, 18, 18, 18, 0, 18, 18] [18, 18, 18, 18, 18, 18, 0, 18] [18, 18, 18, 18, 18, 18, 18, 0] ``` Pull Request resolved: pytorch#129510 Approved by: https://github.com/weifengpy
- Loading branch information
1 parent
305ba62
commit 67416a2
Showing
8 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) | ||
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp> | ||
|
||
#include <c10/cuda/CUDAException.h> | ||
#include <c10/cuda/driver_api.h> | ||
|
||
#include <cuda_runtime.h> | ||
#include <nvml.h> | ||
|
||
namespace { | ||
|
||
constexpr int max_nvlinks = 64; | ||
|
||
std::string get_bus_id(int device_idx) { | ||
char bus_id[80]; | ||
cudaDeviceProp prop{}; | ||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx)); | ||
snprintf( | ||
bus_id, | ||
sizeof(bus_id), | ||
NVML_DEVICE_PCI_BUS_ID_FMT, | ||
prop.pciDomainID, | ||
prop.pciBusID, | ||
prop.pciDeviceID); | ||
return std::string(bus_id); | ||
} | ||
|
||
struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { | ||
c10::intrusive_ptr<c10d::DMAConnectivity> detect() override { | ||
int num_devices; | ||
C10_CUDA_CHECK(cudaGetDeviceCount(&num_devices)); | ||
|
||
std::vector<std::vector<int>> matrix; | ||
matrix.reserve(num_devices); | ||
for (int i = 0; i < num_devices; ++i) { | ||
matrix.emplace_back(num_devices, 0); | ||
} | ||
|
||
// Obtain the bus_id for all visible devices | ||
std::unordered_map<std::string, int> bus_id_to_device_idx; | ||
std::vector<std::string> bus_ids; | ||
bus_ids.reserve(num_devices); | ||
for (int i = 0; i < num_devices; ++i) { | ||
auto bus_id = get_bus_id(i); | ||
bus_id_to_device_idx.emplace(bus_id, i); | ||
bus_ids.push_back(std::move(bus_id)); | ||
} | ||
|
||
// Obtain the nvml device for all bus_ids | ||
auto driver_api = c10::cuda::DriverAPI::get(); | ||
std::vector<nvmlDevice_t> nvml_devices(num_devices, nullptr); | ||
for (int i = 0; i < num_devices; ++i) { | ||
TORCH_CHECK_EQ( | ||
driver_api->nvmlDeviceGetHandleByPciBusId_v2_( | ||
bus_ids[i].c_str(), &nvml_devices[i]), | ||
NVML_SUCCESS); | ||
} | ||
|
||
std::vector<int> switch_link_count(num_devices, 0); | ||
for (int i = 0; i < num_devices; ++i) { | ||
for (int link = 0; link < max_nvlinks; ++link) { | ||
nvmlReturn_t ret; | ||
nvmlIntNvLinkDeviceType_t deviceType; | ||
ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_( | ||
nvml_devices[i], link, &deviceType); | ||
if (ret != NVML_SUCCESS) { | ||
// We've exhausted the NVLinks connected to this device. This error | ||
// is benign. There doesn't seem to be a reliable way to obtain the | ||
// maximum link value that can be passed to the API. Therefore, we | ||
// simply increment the link value until the API fails or we reach a | ||
// predefined maximum value. | ||
break; | ||
} | ||
// Remote device is GPU | ||
if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) { | ||
nvmlPciInfo_t pciInfo; | ||
TORCH_CHECK_EQ( | ||
driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_( | ||
nvml_devices[i], link, &pciInfo), | ||
NVML_SUCCESS); | ||
auto it = bus_id_to_device_idx.find(pciInfo.busId); | ||
if (it != bus_id_to_device_idx.end()) { | ||
if (i != it->second) { | ||
matrix[i][it->second] += 1; | ||
} | ||
} | ||
// Remote device is NVSwitch | ||
} else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) { | ||
switch_link_count[i] += 1; | ||
} | ||
} | ||
} | ||
|
||
// Process NVSwitch connections. | ||
// For simplicity, we assume that all NVSwitches are interconnected. | ||
for (int i = 0; i < num_devices; ++i) { | ||
for (int j = 0; j < num_devices; ++j) { | ||
if (i == j) { | ||
continue; | ||
} | ||
matrix[i][j] += std::min(switch_link_count[i], switch_link_count[j]); | ||
} | ||
} | ||
|
||
return c10::make_intrusive<c10d::DMAConnectivity>( | ||
c10::DeviceType::CUDA, "nvlink", std::move(matrix)); | ||
} | ||
}; | ||
|
||
struct RegisterDetector { | ||
RegisterDetector() { | ||
register_dma_connectivity_detector( | ||
c10::DeviceType::CUDA, "nvlink", c10::make_intrusive<NVLinkDetector>()); | ||
} | ||
}; | ||
|
||
static RegisterDetector register_detector_; | ||
|
||
} // namespace | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp> | ||
|
||
namespace { | ||
|
||
std::string get_detector_key( | ||
c10::DeviceType device_type, | ||
std::string connection_type) { | ||
std::ostringstream oss; | ||
oss << device_type << "/" << connection_type; | ||
return oss.str(); | ||
} | ||
|
||
class DetectorMap { | ||
public: | ||
static DetectorMap& get() { | ||
static DetectorMap instance; | ||
return instance; | ||
} | ||
|
||
void register_detector( | ||
c10::DeviceType device_type, | ||
const std::string& connection_type, | ||
c10::intrusive_ptr<c10d::DMAConnectivityDetector> detector) { | ||
auto key = get_detector_key(device_type, connection_type); | ||
detector_map_[key] = std::move(detector); | ||
} | ||
|
||
c10::intrusive_ptr<c10d::DMAConnectivity> detect( | ||
c10::DeviceType device_type, | ||
const std::string& connection_type) { | ||
auto key = get_detector_key(device_type, connection_type); | ||
{ | ||
auto it = cached_.find(key); | ||
if (it != cached_.end()) { | ||
return it->second; | ||
} | ||
} | ||
|
||
auto it = detector_map_.find(key); | ||
TORCH_CHECK( | ||
it != detector_map_.end(), | ||
"DMA connectivity detector for ", | ||
device_type, | ||
" over ", | ||
connection_type, | ||
" is not available"); | ||
auto detector = it->second; | ||
auto connectivity = detector->detect(); | ||
cached_[key] = connectivity; | ||
return connectivity; | ||
} | ||
|
||
private: | ||
DetectorMap() = default; | ||
DetectorMap(const DetectorMap&) = delete; | ||
DetectorMap& operator=(const DetectorMap&) = delete; | ||
|
||
std::unordered_map< | ||
std::string, | ||
c10::intrusive_ptr<c10d::DMAConnectivityDetector>> | ||
detector_map_; | ||
|
||
std::unordered_map<std::string, c10::intrusive_ptr<c10d::DMAConnectivity>> | ||
cached_; | ||
}; | ||
|
||
}; // namespace | ||
|
||
namespace c10d { | ||
|
||
DMAConnectivity::DMAConnectivity( | ||
c10::DeviceType device_type, | ||
std::string connection_type, | ||
std::vector<std::vector<int>> matrix) | ||
: device_type(device_type), | ||
connection_type(connection_type), | ||
matrix(std::move(matrix)) {} | ||
|
||
void register_dma_connectivity_detector( | ||
c10::DeviceType device_type, | ||
const std::string& connection_type, | ||
c10::intrusive_ptr<DMAConnectivityDetector> detector) { | ||
return DetectorMap::get().register_detector( | ||
device_type, connection_type, std::move(detector)); | ||
} | ||
|
||
c10::intrusive_ptr<DMAConnectivity> detect_dma_connectivity( | ||
c10::DeviceType device_type, | ||
const std::string& connection_type) { | ||
return DetectorMap::get().detect(device_type, connection_type); | ||
} | ||
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#pragma once | ||
|
||
#include <optional> | ||
|
||
#include <ATen/ATen.h> | ||
|
||
namespace c10d { | ||
|
||
struct TORCH_API DMAConnectivity : c10::intrusive_ptr_target { | ||
c10::DeviceType device_type; | ||
std::string connection_type; | ||
|
||
// This is an NxN matrix representing the connectivity between N devices, | ||
// where each element matrix[i][j] indicates the connectivity between device | ||
// i and device j. A value of 0 denotes that there is no connection between | ||
// device i and j. The meaning of non-zero values are specific to the | ||
// connection type (e.g., for NVLink it represents the number of NVLinks). | ||
std::vector<std::vector<int>> matrix; | ||
|
||
explicit DMAConnectivity( | ||
c10::DeviceType device_type, | ||
std::string connection_type, | ||
std::vector<std::vector<int>> matrix); | ||
}; | ||
|
||
struct DMAConnectivityDetector : c10::intrusive_ptr_target { | ||
virtual c10::intrusive_ptr<DMAConnectivity> detect() = 0; | ||
virtual ~DMAConnectivityDetector() {} | ||
}; | ||
|
||
C10_EXPORT void register_dma_connectivity_detector( | ||
c10::DeviceType device_type, | ||
const std::string& connection_type, | ||
c10::intrusive_ptr<DMAConnectivityDetector> detector); | ||
|
||
TORCH_API c10::intrusive_ptr<DMAConnectivity> detect_dma_connectivity( | ||
c10::DeviceType device_type, | ||
const std::string& connection_type); | ||
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters