@@ -440,10 +440,30 @@ def set_unsafe(self, name):
440440 raise RuntimeError ("Backend not found" )
441441 self .__name = name
442442
443- def __init__ (self ):
444-
443+ def _loadlibs (self ):
444+ """
445+ function that loads ArrayFire upstream libraries
446+ """
445447 more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
446448
449+ # Iterate in reverse order of preference
450+ for name in ('cpu' , 'opencl' , 'cuda' , '' ):
451+ libnames = self .__libname (name )
452+ for libname in libnames :
453+ try :
454+ ct .cdll .LoadLibrary (libname )
455+ __name = 'unified' if name == '' else name
456+ self .__clibs [__name ] = ct .CDLL (libname )
457+ self .__name = __name
458+ break ;
459+ except :
460+ pass
461+
462+ if (self .__name is None ):
463+ raise RuntimeError ("Could not load any ArrayFire libraries.\n " +
464+ more_info_str )
465+
466+ def __init__ (self ):
447467 pre , post , AF_PATH , CUDA_FOUND = _setup ()
448468
449469 self .__pre = pre
@@ -468,7 +488,6 @@ def __init__(self):
468488 'cpu' : 1 ,
469489 'cuda' : 2 ,
470490 'opencl' : 4 }
471-
472491 # Try to pre-load forge library if it exists
473492 libnames = self .__libname ('forge' , '' )
474493 for libname in libnames :
@@ -477,30 +496,15 @@ def __init__(self):
477496 except :
478497 pass
479498
480- # Iterate in reverse order of preference
481- for name in ('cpu' , 'opencl' , 'cuda' , '' ):
482- libnames = self .__libname (name )
483- for libname in libnames :
484- try :
485- ct .cdll .LoadLibrary (libname )
486- __name = 'unified' if name == '' else name
487- self .__clibs [__name ] = ct .CDLL (libname )
488- self .__name = __name
489- break ;
490- except :
491- pass
492-
493- if (self .__name is None ):
494- raise RuntimeError ("Could not load any ArrayFire libraries.\n " +
495- more_info_str )
496-
497499 def get_id (self , name ):
498500 return self .__backend_name_map [name ]
499501
500502 def get_name (self , bk_id ):
501503 return self .__backend_map [bk_id ]
502504
503505 def get (self ):
506+ if (self .__clibs [self .__name ] is None ):
507+ self ._loadlibs ()
504508 return self .__clibs [self .__name ]
505509
506510 def name (self ):
0 commit comments