|
15 | 15 |
|
16 | 16 | try: |
17 | 17 | if is_hip(): |
18 | | - from amdsmi import (AmdSmiException, |
19 | | - amdsmi_get_processor_handle_from_bdf, amdsmi_init, |
20 | | - amdsmi_shut_down, amdsmi_topo_get_link_type) |
| 18 | + from amdsmi import (AmdSmiException, amdsmi_get_processor_handles, |
| 19 | + amdsmi_init, amdsmi_shut_down, |
| 20 | + amdsmi_topo_get_link_type) |
21 | 21 | else: |
22 | 22 | import pynvml |
23 | 23 |
|
@@ -62,25 +62,22 @@ def _is_full_nvlink(device_ids: List[int], world_size) -> bool: |
62 | 62 | so it works on real physical device ids. |
63 | 63 | """ |
64 | 64 | if is_hip(): |
65 | | - # get devices' BDF in order to get XGMI link info from amdsmi |
66 | | - bdf = custom_ar.get_device_bdf(torch.cuda.current_device()) |
67 | | - all_bdf = [0] * world_size |
68 | | - dist.all_gather_object(all_bdf, bdf) |
69 | | - hsmi = [None] * world_size |
70 | | - try: |
71 | | - for i in range(world_size): |
72 | | - bdf_str = str(bytes(all_bdf[i]).decode("utf-8")) |
73 | | - hsmi[i] = amdsmi_get_processor_handle_from_bdf(bdf_str) |
74 | | - for i in range(world_size): |
75 | | - if i != 0: |
76 | | - link_type = amdsmi_topo_get_link_type(hsmi[0], hsmi[i]) |
77 | | - # type is 2 for XGMI |
78 | | - if link_type['hops'] != 1 or link_type['type'] != 2: |
| 65 | + # On ROCm, we instead query whether GPUs are connected by 1-hop XGMI |
| 66 | + handles = [amdsmi_get_processor_handles()[i] for i in device_ids] |
| 67 | + for i, handle in enumerate(handles): |
| 68 | + for j, peer_handle in enumerate(handles): |
| 69 | + if i < j: |
| 70 | + try: |
| 71 | + link_type = amdsmi_topo_get_link_type( |
| 72 | + handle, peer_handle) |
| 73 | + # An XGMI link reports type == 2; also require a direct (1-hop) link |
| 74 | + if link_type["hops"] != 1 or link_type["type"] != 2: |
| 75 | + return False |
| 76 | + except AmdSmiException as error: |
| 77 | + logger.error( |
| 78 | + "AMD link detection failed.", |
| 79 | + exc_info=error) |
79 | 80 | return False |
80 | | - except AmdSmiException as e: |
81 | | - logger.warning(e) |
82 | | - return False |
83 | | - return True |
84 | 81 | else: |
85 | 82 | handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] |
86 | 83 | for i, handle in enumerate(handles): |
|
0 commit comments