@@ -553,6 +553,34 @@ def nccl_integrity_check(filepath):
553553 return version .value
554554
555555
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """Resolve a shared-library file name to its full path.

    The dynamic linker cache (``/sbin/ldconfig -p``) is consulted first; if
    it yields nothing, the directories listed in ``LD_LIBRARY_PATH`` are
    searched. Results are memoized per ``lib_name`` for the lifetime of the
    process, so later changes to the environment are not picked up for a
    name that has already been resolved.

    Args:
        lib_name: Full file name of the library, including prefix and
            suffix (for example ``"libcuda.so.1"``).

    Returns:
        The path of the first matching library.

    Raises:
        ValueError: If the library is found neither in the linker cache
            nor in any ``LD_LIBRARY_PATH`` directory.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    try:
        output = subprocess.check_output(["/sbin/ldconfig", "-p"])
    except (OSError, subprocess.CalledProcessError):
        # A missing or failing `ldconfig` must not prevent the
        # `LD_LIBRARY_PATH` fallback below from being tried.
        output = b""

    # Each cache entry looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    exact_matches = []
    loose_matches = []
    for line in os.fsdecode(output).splitlines():
        entry, arrow, path = line.partition(" => ")
        if not arrow or lib_name not in line:
            # Skip the header line and unrelated entries.
            continue
        # Keep everything after the arrow so paths with spaces survive.
        path = path.strip()
        fields = entry.split()
        if fields and fields[0] == lib_name:
            exact_matches.append(path)
        else:
            loose_matches.append(path)
    # Prefer entries whose name is exactly `lib_name`; substring matches are
    # kept only as a fallback so a similarly named library cannot win.
    locs = exact_matches or loose_matches

    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if not locs and env_ld_library_path:
        candidates = (
            os.path.join(directory, lib_name)
            for directory in env_ld_library_path.split(":")
        )
        locs = [path for path in candidates if os.path.exists(path)]

    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]
582+
583+
556584def find_nccl_library ():
557585 so_file = os .environ .get ("VLLM_NCCL_SO_PATH" , "" )
558586
@@ -572,9 +600,9 @@ def find_nccl_library():
572600 )
573601 else :
574602 if torch .version .cuda is not None :
575- so_file = vllm_nccl_path or "libnccl.so.2"
603+ so_file = vllm_nccl_path or find_library ( "libnccl.so.2" )
576604 elif torch .version .hip is not None :
577- so_file = "librccl.so.1"
605+ so_file = find_library ( "librccl.so.1" )
578606 else :
579607 raise ValueError ("NCCL only supports CUDA and ROCm backends." )
580608 logger .info (f"Found nccl from library { so_file } " )
0 commit comments