Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions deep_ep/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def all_gather_object(obj):
# Synchronize NVSHMEM unique IDs
root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
self._setup_device_hca_mapping()

# Enable IBGDA
assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'
Expand Down Expand Up @@ -133,6 +135,36 @@ def all_gather_object(obj):
self.runtime.sync(device_ids, ipc_handles, root_unique_id)
assert self.runtime.is_available()

def _setup_device_hca_mapping(self):
    """
    Set up the CUDA-device-to-NIC mapping from the DEEP_EP_DEVICE_TO_HCA_MAPPING environment variable.

    The variable holds comma-separated entries of the form `<device_id>:<hca_name>`, for example
    "0:mlx5_0:1,1:mlx5_1:1,...". Only the first colon of an entry separates the two fields, so the
    HCA name keeps any further suffix such as ":1". Whitespace around entries and fields is ignored,
    and empty entries (e.g. produced by a trailing comma) are skipped.

    If the variable is not set, this method does nothing. Otherwise it sets
    `NVSHMEM_ENABLE_PE_MAPPING` to '1' and `NVSHMEM_HCA_LIST` to the HCA name mapped to the
    current CUDA device. When `CUDA_VISIBLE_DEVICES` is set, the current device index is first
    translated into the ID listed at that position, so the mapping is keyed by those IDs.

    Raises:
        AssertionError: if an entry has no colon, an HCA name is empty, `CUDA_VISIBLE_DEVICES`
            has too few entries or a non-integer entry at the current index, or the current
            device has no entry in the mapping.
        ValueError: if the device ID of an entry is not an integer.
    """
    mapping_str = os.environ.get('DEEP_EP_DEVICE_TO_HCA_MAPPING')
    if mapping_str is None:
        return

    # Parse mapping string like "0:mlx5_0:1,1:mlx5_1:1,..."
    # NOTES: validation keeps using `assert` so the exception type seen by callers stays `AssertionError`
    device_mapping = {}
    for raw_entry in mapping_str.split(','):
        entry = raw_entry.strip()
        if not entry:
            # Tolerate stray separators such as a trailing comma
            continue
        assert ':' in entry, f"Invalid mapping format '{entry}' in DEEP_EP_DEVICE_TO_HCA_MAPPING. Expected format: '<device_id>:<hca_name>'"
        # Split only on the first colon, the rest (including ":1") belongs to the HCA name
        device_str, hca_name = (field.strip() for field in entry.split(':', 1))
        try:
            device_id = int(device_str)
        except ValueError as err:
            # Same exception type as a bare `int()` failure, but naming the offending variable and entry
            raise ValueError(
                f"Invalid device ID '{device_str}' in DEEP_EP_DEVICE_TO_HCA_MAPPING entry '{entry}'. "
                f"Expected format: '<device_id>:<hca_name>'") from err
        # An empty name would otherwise be exported as an empty `NVSHMEM_HCA_LIST`
        assert hca_name, f"Empty HCA name for device {device_id} in DEEP_EP_DEVICE_TO_HCA_MAPPING. Expected format: '<device_id>:<hca_name>'"
        device_mapping[device_id] = hca_name

    # Get current device and set appropriate HCA
    current_device = torch.cuda.current_device()

    # Translate the current device index through CUDA_VISIBLE_DEVICES
    visible_devices_str = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices_str is not None:
        visible_devices = [device.strip() for device in visible_devices_str.split(',')]
        assert len(visible_devices) > current_device, f"CUDA_VISIBLE_DEVICES has {len(visible_devices)} entries which is fewer than the current device {current_device}"
        assert visible_devices[current_device].isdigit(), "DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices"
        current_device = int(visible_devices[current_device])

    assert current_device in device_mapping, f"Current CUDA device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
    os.environ['NVSHMEM_ENABLE_PE_MAPPING'] = '1'
    os.environ['NVSHMEM_HCA_LIST'] = device_mapping[current_device]

def destroy(self):
"""
Destroy the cpp runtime and release resources.
Expand Down