Skip to content

Commit 6f04bb0

Browse files
committed
load local nvrtc-builtins when possible
1 parent c8ed913 commit 6f04bb0

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

arrayfire/library.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -578,22 +578,32 @@ def _setup():
578578

579579
class _clibrary(object):
580580

581+
def __find_nvrtc_builtins_libname(self, search_path):
582+
filelist = os.listdir(search_path)
583+
for f in filelist:
584+
if 'nvrtc-builtins' in f:
585+
return f
586+
return None
587+
581588
def __libname(self, name, head='af', ver_major=AF_VER_MAJOR):
582589
post = self.__post.replace(_VER_MAJOR_PLACEHOLDER, ver_major)
583590
libname = self.__pre + head + name + post
584591

585592
if os.path.isdir(self.AF_PATH + '/lib64'):
586-
libname_full = self.AF_PATH + '/lib64/' + libname
593+
path_search = self.AF_PATH + '/lib64/'
587594
else:
588-
libname_full = self.AF_PATH + '/lib/' + libname
595+
path_search = self.AF_PATH + '/lib/'
589596

590597
if platform.architecture()[0][:2] == '64':
591-
libname_site = sys.prefix + '/lib64/' + libname
598+
path_site = sys.prefix + '/lib64/'
592599
else:
593-
libname_site = sys.prefix + '/lib/' + libname
600+
path_site = sys.prefix + '/lib/'
594601

595-
libname_local = self.AF_PYMODULE_PATH + libname
596-
return (libname, libname_full, libname_site, libname_local)
602+
path_local = self.AF_PYMODULE_PATH
603+
return [('', libname),
604+
(path_search, libname),
605+
(path_site, libname),
606+
(path_local,libname)]
597607

598608
def set_unsafe(self, name):
599609
lib = self.__clibs[name]
@@ -646,14 +656,15 @@ def __init__(self):
646656

647657
for libname in libnames:
648658
try:
649-
ct.cdll.LoadLibrary(libname)
659+
full_libname = libname[0] + libname[1]
660+
ct.cdll.LoadLibrary(full_libname)
650661
if VERBOSE_LOADS:
651-
print('Loaded ' + libname)
662+
print('Loaded ' + full_libname)
652663
break
653664
except OSError:
654665
if VERBOSE_LOADS:
655666
traceback.print_exc()
656-
print('Unable to load ' + libname)
667+
print('Unable to load ' + full_libname)
657668
pass
658669

659670
c_dim4 = c_dim_t*4
@@ -665,21 +676,36 @@ def __init__(self):
665676
libnames = reversed(self.__libname(name))
666677
for libname in libnames:
667678
try:
668-
ct.cdll.LoadLibrary(libname)
679+
full_libname = libname[0] + libname[1]
680+
681+
ct.cdll.LoadLibrary(full_libname)
669682
__name = 'unified' if name == '' else name
670-
clib = ct.CDLL(libname)
683+
clib = ct.CDLL(full_libname)
671684
self.__clibs[__name] = clib
672685
err = clib.af_randu(c_pointer(out), 4, c_pointer(dims), Dtype.f32.value)
673686
if (err == ERR.NONE.value):
674687
self.__name = __name
675688
clib.af_release_array(out)
676689
if VERBOSE_LOADS:
677-
print('Loaded ' + libname)
690+
print('Loaded ' + full_libname)
691+
692+
# load nvrtc-builtins library if using cuda
693+
if name == 'cuda':
694+
nvrtc_name = self.__find_nvrtc_builtins_libname(libname[0])
695+
if nvrtc_name:
696+
ct.cdll.LoadLibrary(libname[0] + nvrtc_name)
697+
698+
if VERBOSE_LOADS:
699+
print('Loaded ' + libname[0] + nvrtc_name)
700+
else:
701+
if VERBOSE_LOADS:
702+
print('Could not find local nvrtc-builtins libarary')
703+
678704
break;
679705
except OSError:
680706
if VERBOSE_LOADS:
681707
traceback.print_exc()
682-
print('Unable to load ' + libname)
708+
print('Unable to load ' + full_libname)
683709
pass
684710

685711
if (self.__name is None):
@@ -707,6 +733,7 @@ def parse(self, res):
707733
lst.append(key)
708734
return tuple(lst)
709735

736+
710737
backend = _clibrary()
711738

712739
def set_backend(name, unsafe=False):

0 commit comments

Comments
 (0)