@@ -115,7 +115,7 @@ def __init__(self, module_name):
         """Init LazyImport object.
 
         Args:
-            module_name (string): The name of module imported later
+            module_name (string): The name of module imported later
         """
         self.module_name = module_name
         self.module = None
@@ -143,13 +143,31 @@ def __call__(self, *args, **kwargs):
 auto_gptq = LazyImport("auto_gptq")
 htcore = LazyImport("habana_frameworks.torch.core")
 
+################ Check available sys.module to decide behavior #################
+def is_package_available(package_name):
+    """Check if the package exists in the environment without importing.
+
+    Args:
+        package_name (str): package name
+    """
+    from importlib.util import find_spec
+
+    package_spec = find_spec(package_name)
+    return package_spec is not None
+
+
+## check hpex
+if is_package_available("habana_frameworks"):
+    _hpex_available = True
+    import habana_frameworks.torch.hpex  # pylint: disable=E0401
+else:
+    _hpex_available = False
+
 
 @torch._dynamo.disable()
 @lru_cache(None)
-def is_optimum_habana_available():
-    from transformers.utils.import_utils import is_optimum_available
-
-    return is_optimum_available() and importlib.util.find_spec("optimum.habana") is not None
+def is_hpex_available():
+    return _hpex_available
 
 
 def get_module(module, key):
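For reviewers, a standalone sketch of the availability-probe pattern this hunk introduces. It is illustrative only, not part of the patch; the uppercase flag name and the __main__ block are additions for demonstration.

    from importlib.util import find_spec

    def is_package_available(package_name: str) -> bool:
        # find_spec() resolves the package on sys.path without importing it,
        # so the probe is cheap and side-effect free on machines without the package.
        return find_spec(package_name) is not None

    # Computed once at import time, mirroring the _hpex_available flag in the patch.
    _HPEX_AVAILABLE = is_package_available("habana_frameworks")

    def is_hpex_available() -> bool:
        return _HPEX_AVAILABLE

    if __name__ == "__main__":
        # Prints False on a machine without the Habana software stack and never
        # triggers the (potentially heavy) habana_frameworks import.
        print("HPEX available:", is_hpex_available())

The practical difference from the removed is_optimum_habana_available() is that the check no longer depends on transformers/optimum being installed; it only asks whether the habana_frameworks package can be found.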
@@ -552,7 +570,7 @@ def is_valid_digit(s):
     if torch.cuda.is_available():
         device = torch.device("cuda")
         # logger.info("Using GPU device")
-    elif is_optimum_habana_available():  # pragma: no cover
+    elif is_hpex_available():  # pragma: no cover
         device = torch.device("hpu")
         # logger.info("Using HPU device")
     elif torch.xpu.is_available():  # pragma: no cover
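A minimal sketch of the device-selection precedence shown in this hunk, assuming the is_hpex_available() helper added earlier in the diff. The function name pick_device and the CPU fallback are illustrative; the fallback branch is not visible in the hunk.

    import torch

    def pick_device() -> torch.device:
        # Same precedence as the patched code: CUDA first, then HPU via the
        # new is_hpex_available() probe, then XPU, otherwise CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if is_hpex_available():  # helper added in this diff
            return torch.device("hpu")
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
        return torch.device("cpu")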