@@ -578,22 +578,32 @@ def _setup():
578578
579579class _clibrary (object ):
580580
581+ def __find_nvrtc_builtins_libname (self , search_path ):
582+ filelist = os .listdir (search_path )
583+ for f in filelist :
584+ if 'nvrtc-builtins' in f :
585+ return f
586+ return None
587+
581588 def __libname (self , name , head = 'af' , ver_major = AF_VER_MAJOR ):
582589 post = self .__post .replace (_VER_MAJOR_PLACEHOLDER , ver_major )
583590 libname = self .__pre + head + name + post
584591
585592 if os .path .isdir (self .AF_PATH + '/lib64' ):
586- libname_full = self .AF_PATH + '/lib64/' + libname
593+ path_search = self .AF_PATH + '/lib64/'
587594 else :
588- libname_full = self .AF_PATH + '/lib/' + libname
595+ path_search = self .AF_PATH + '/lib/'
589596
590597 if platform .architecture ()[0 ][:2 ] == '64' :
591- libname_site = sys .prefix + '/lib64/' + libname
598+ path_site = sys .prefix + '/lib64/'
592599 else :
593- libname_site = sys .prefix + '/lib/' + libname
600+ path_site = sys .prefix + '/lib/'
594601
595- libname_local = self .AF_PYMODULE_PATH + libname
596- return (libname , libname_full , libname_site , libname_local )
602+ path_local = self .AF_PYMODULE_PATH
603+ return [('' , libname ),
604+ (path_search , libname ),
605+ (path_site , libname ),
606+ (path_local ,libname )]
597607
598608 def set_unsafe (self , name ):
599609 lib = self .__clibs [name ]
@@ -646,14 +656,15 @@ def __init__(self):
646656
647657 for libname in libnames :
648658 try :
649- ct .cdll .LoadLibrary (libname )
659+ full_libname = libname [0 ] + libname [1 ]
660+ ct .cdll .LoadLibrary (full_libname )
650661 if VERBOSE_LOADS :
651- print ('Loaded ' + libname )
662+ print ('Loaded ' + full_libname )
652663 break
653664 except OSError :
654665 if VERBOSE_LOADS :
655666 traceback .print_exc ()
656- print ('Unable to load ' + libname )
667+ print ('Unable to load ' + full_libname )
657668 pass
658669
659670 c_dim4 = c_dim_t * 4
@@ -665,21 +676,36 @@ def __init__(self):
665676 libnames = reversed (self .__libname (name ))
666677 for libname in libnames :
667678 try :
668- ct .cdll .LoadLibrary (libname )
679+ full_libname = libname [0 ] + libname [1 ]
680+
681+ ct .cdll .LoadLibrary (full_libname )
669682 __name = 'unified' if name == '' else name
670- clib = ct .CDLL (libname )
683+ clib = ct .CDLL (full_libname )
671684 self .__clibs [__name ] = clib
672685 err = clib .af_randu (c_pointer (out ), 4 , c_pointer (dims ), Dtype .f32 .value )
673686 if (err == ERR .NONE .value ):
674687 self .__name = __name
675688 clib .af_release_array (out )
676689 if VERBOSE_LOADS :
677- print ('Loaded ' + libname )
690+ print ('Loaded ' + full_libname )
691+
692+ # load nvrtc-builtins library if using cuda
693+ if name == 'cuda' :
694+ nvrtc_name = self .__find_nvrtc_builtins_libname (libname [0 ])
695+ if nvrtc_name :
696+ ct .cdll .LoadLibrary (libname [0 ] + nvrtc_name )
697+
698+ if VERBOSE_LOADS :
699+ print ('Loaded ' + libname [0 ] + nvrtc_name )
700+ else :
701+ if VERBOSE_LOADS :
702+ print ('Could not find local nvrtc-builtins libarary' )
703+
678704 break ;
679705 except OSError :
680706 if VERBOSE_LOADS :
681707 traceback .print_exc ()
682- print ('Unable to load ' + libname )
708+ print ('Unable to load ' + full_libname )
683709 pass
684710
685711 if (self .__name is None ):
@@ -707,6 +733,7 @@ def parse(self, res):
707733 lst .append (key )
708734 return tuple (lst )
709735
736+
710737backend = _clibrary ()
711738
712739def set_backend (name , unsafe = False ):
0 commit comments