3030
3131#include "common_ofi.h"
3232#include "opal/constants.h"
33+ #include "opal/mca/accelerator/accelerator.h"
3334#include "opal/mca/base/mca_base_framework.h"
3435#include "opal/mca/base/mca_base_var.h"
3536#include "opal/mca/hwloc/base/base.h"
3839#include "opal/util/argv.h"
3940#include "opal/util/show_help.h"
4041
42+ extern opal_accelerator_base_module_t opal_accelerator ;
4143opal_common_ofi_module_t opal_common_ofi = {.prov_include = NULL ,
4244 .prov_exclude = NULL ,
4345 .output = -1 };
@@ -915,18 +917,190 @@ static uint32_t get_package_rank(opal_process_info_t *process_info)
915917 return (uint32_t ) process_info -> myprocid .rank ;
916918}
917919
920+ static int get_parent_distance (hwloc_obj_t parent , hwloc_obj_t child , int * distance )
921+ {
922+ int dist = 0 ;
923+
924+ while (child != parent ) {
925+ if (!child ) {
926+ return OPAL_ERROR ;
927+ }
928+ child = child -> parent ;
929+ ++ dist ;
930+ }
931+
932+ * distance = dist ;
933+ return OPAL_SUCCESS ;
934+ }
935+
936+ #if OPAL_OFI_PCI_DATA_AVAILABLE
937+ /**
938+ * @brief Attempt to find a nearest provider from the accelerator.
939+ * Check if opal_accelerator is initialized with a valid PCI device, and find a provider from the
940+ * shortest distance.
941+ * Special cases:
942+ * 1. If not accelerator device is available, returns OPAL_ERR_NOT_AVAILABLE.
943+ * 2. If the provider does not have PCI attributers, we do not attempt to make a selection, and
944+ * return OPAL_ERR_NOT_AVAILABLE.
945+ * 3. If there are more than 1 providers with the same equal distance, break the tie using a modulo
946+ * i.e. (local rank on the same accelerator) % (number of nearest providers)
947+ * @param[in] provider_list linked list of providers
948+ * @param[in] num_providers number of providers
949+ * @param[in] accl_id Accelerator id
950+ * @param[in] device_rank local rank on the accelerator
951+ * @param[out] provider pointer to the selected provider
952+ * @return OPAL_SUCCESS if a provider is successfully selected
953+ * OPAL_ERR_NOT_AVAILABLE if a provider cannot be decided deterministically
954+ * OPAL_ERROR if a fatal error happened
955+ */
956+ static int find_nearest_provider_from_accelerator (struct fi_info * provider_list ,
957+ size_t num_providers ,
958+ int accl_id ,
959+ uint32_t device_rank ,
960+ struct fi_info * * provider )
961+ {
962+ hwloc_obj_t accl_dev = NULL , prov_dev = NULL , common_ancestor = NULL ;
963+ int ret = -1 , accl_distance = -1 , prov_distance = -1 , min_distance = INT_MAX ;
964+ opal_accelerator_pci_attr_t accl_pci_attr = {0 };
965+ struct fi_info * current_provider = NULL ;
966+ struct fi_pci_attr pci = {0 };
967+ uint32_t distances [num_providers ], * distance = distances ;
968+ uint32_t near_provider_count = 0 , provider_rank = 0 ;
969+
970+ memset (distances , 0 , sizeof (distances ));
971+
972+ ret = opal_accelerator .get_device_pci_attr (accl_id , & accl_pci_attr );
973+ if (OPAL_SUCCESS != ret ) {
974+ opal_output_verbose (1 , opal_common_ofi .output ,
975+ "%s:%d:Accelerator PCI info is not available" , __FILE__ , __LINE__ );
976+ return OPAL_ERROR ;
977+ }
978+
979+ accl_dev = hwloc_get_pcidev_by_busid (opal_hwloc_topology , accl_pci_attr .domain_id ,
980+ accl_pci_attr .bus_id , accl_pci_attr .device_id ,
981+ accl_pci_attr .function_id );
982+ if (NULL == accl_dev ) {
983+ opal_output_verbose (1 , opal_common_ofi .output ,
984+ "%s:%d:Failed to find accelerator PCI device" , __FILE__ , __LINE__ );
985+ return OPAL_ERROR ;
986+ }
987+
988+ opal_output_verbose (1 , opal_common_ofi .output ,
989+ "%s:%d:Found accelerator device %d: %04x:%02x:%02x.%x VID: %x DID: %x" ,
990+ __FILE__ , __LINE__ , accl_id , accl_pci_attr .domain_id , accl_pci_attr .bus_id ,
991+ accl_pci_attr .device_id , accl_pci_attr .function_id ,
992+ accl_dev -> attr -> pcidev .vendor_id , accl_dev -> attr -> pcidev .device_id );
993+
994+ current_provider = provider_list ;
995+ while (NULL != current_provider ) {
996+ common_ancestor = NULL ;
997+ if (0 == check_provider_attr (provider_list , current_provider )
998+ && OPAL_SUCCESS == get_provider_nic_pci (current_provider , & pci )) {
999+ prov_dev = hwloc_get_pcidev_by_busid (opal_hwloc_topology , pci .domain_id , pci .bus_id ,
1000+ pci .device_id , pci .function_id );
1001+ if (NULL == prov_dev ) {
1002+ opal_output_verbose (1 , opal_common_ofi .output ,
1003+ "%s:%d:Failed to find provider PCI device" , __FILE__ , __LINE__ );
1004+ return OPAL_ERROR ;
1005+ }
1006+
1007+ common_ancestor = hwloc_get_common_ancestor_obj (opal_hwloc_topology , accl_dev ,
1008+ prov_dev );
1009+ if (!common_ancestor ) {
1010+ opal_output_verbose (
1011+ 1 , opal_common_ofi .output ,
1012+ "%s:%d:Failed to find common ancestor of accelerator and provider PCI device" ,
1013+ __FILE__ , __LINE__ );
1014+ /**
1015+ * Return error because any 2 PCI devices should share at least one common ancestor,
1016+ * i.e. root
1017+ */
1018+ return OPAL_ERROR ;
1019+ }
1020+
1021+ ret = get_parent_distance (common_ancestor , accl_dev , & accl_distance );
1022+ if (OPAL_SUCCESS != ret ) {
1023+ opal_output_verbose (
1024+ 1 , opal_common_ofi .output ,
1025+ "%s:%d:Failed to get distance between common ancestor and accelerator device" ,
1026+ __FILE__ , __LINE__ );
1027+ return OPAL_ERROR ;
1028+ }
1029+
1030+ ret = get_parent_distance (common_ancestor , prov_dev , & prov_distance );
1031+ if (OPAL_SUCCESS != ret ) {
1032+ opal_output_verbose (
1033+ 1 , opal_common_ofi .output ,
1034+ "%s:%d:Failed to get distance between common ancestor and provider device" ,
1035+ __FILE__ , __LINE__ );
1036+ return OPAL_ERROR ;
1037+ }
1038+
1039+ if (min_distance > accl_distance + prov_distance ) {
1040+ min_distance = accl_distance + prov_distance ;
1041+ near_provider_count = 1 ;
1042+ } else if (min_distance == accl_distance + prov_distance ) {
1043+ ++ near_provider_count ;
1044+ }
1045+ }
1046+
1047+ * (distance ++ ) = !common_ancestor ? 0 : accl_distance + prov_distance ;
1048+ current_provider = current_provider -> next ;
1049+ }
1050+
1051+ if (0 == near_provider_count ) {
1052+ opal_output_verbose (1 , opal_common_ofi .output , "%s:%d:Provider does not have PCI device" ,
1053+ __FILE__ , __LINE__ );
1054+ return OPAL_ERR_NOT_AVAILABLE ;
1055+ }
1056+
1057+ provider_rank = device_rank % near_provider_count ;
1058+
1059+ distance = distances ;
1060+ current_provider = provider_list ;
1061+ while (NULL != current_provider ) {
1062+ if ((uint32_t ) min_distance == * (distance ++ )
1063+ && provider_rank == -- near_provider_count ) {
1064+ * provider = current_provider ;
1065+ return OPAL_SUCCESS ;
1066+ }
1067+
1068+ current_provider = current_provider -> next ;
1069+ }
1070+
1071+ assert (0 == near_provider_count );
1072+
1073+ return OPAL_ERROR ;
1074+ }
1075+ #endif /* OPAL_OFI_PCI_DATA_AVAILABLE */
1076+
1077+
9181078struct fi_info * opal_common_ofi_select_provider (struct fi_info * provider_list ,
9191079 opal_process_info_t * process_info )
9201080{
921- int ret , num_providers = 0 ;
1081+ int ret , num_providers = 0 , accel_id = -1 ;
9221082 struct fi_info * provider = NULL ;
923- uint32_t package_rank = process_info -> my_local_rank ;
1083+ uint32_t package_rank ;
9241084
1085+ /* Current process' local rank on the same package(socket) */
1086+ package_rank = process_info -> proc_is_bound ? get_package_rank (process_info )
1087+ : process_info -> my_local_rank ;
9251088 num_providers = count_providers (provider_list );
926- if (!process_info -> proc_is_bound || 2 > num_providers ) {
1089+
1090+ #if OPAL_OFI_PCI_DATA_AVAILABLE
1091+ ret = opal_accelerator .get_device (& accel_id );
1092+ if (OPAL_SUCCESS != ret ) {
1093+ opal_output_verbose (1 , opal_common_ofi .output , "%s:%d:Accelerator is not available" ,
1094+ __FILE__ , __LINE__ );
1095+ accel_id = -1 ;
1096+ }
1097+ #endif /* OPAL_OFI_PCI_DATA_AVAILABLE */
1098+
1099+ if ((!process_info -> proc_is_bound && 0 > accel_id ) || 2 > num_providers ) {
9271100 goto round_robin ;
9281101 }
9291102
1103+ #if OPAL_OFI_PCI_DATA_AVAILABLE
9301104 /* Initialize opal_hwloc_topology if it is not already */
9311105 ret = opal_hwloc_base_get_topology ();
9321106 if (0 > ret ) {
@@ -935,9 +1109,27 @@ struct fi_info *opal_common_ofi_select_provider(struct fi_info *provider_list,
9351109 __FILE__ , __LINE__ );
9361110 }
9371111
938- package_rank = get_package_rank (process_info );
1112+ if (0 <= accel_id ) {
1113+ /**
1114+ * If accelerator is enabled, select the closest provider to the accelerator.
1115+ * Note: the function expects a local rank on the accelerator to break ties if there are
1116+ * multiple equidistant providers. package_rank is NOT an accurate measure, but a proxy.
1117+ */
1118+ ret = find_nearest_provider_from_accelerator (provider_list , num_providers , accel_id ,
1119+ package_rank , & provider );
1120+ if (OPAL_SUCCESS == ret ) {
1121+ goto out ;
1122+ }
1123+
1124+ opal_output_verbose (1 , opal_common_ofi .output ,
1125+ "%s:%d:Failed to find a provider close to the accelerator. Error: %d" ,
1126+ __FILE__ , __LINE__ , ret );
1127+
1128+ if (!process_info -> proc_is_bound ) {
1129+ goto round_robin ;
1130+ }
1131+ }
9391132
940- #if OPAL_OFI_PCI_DATA_AVAILABLE
9411133 /**
9421134 * If provider PCI BDF information is available, we calculate its physical distance
9431135 * to the current process, and select the provider with the shortest distance.
0 commit comments