Commit f8fc513

fix hpu detect issue
Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent: c898d13

2 files changed: +30 -12 lines

auto_round/autoround.py

Lines changed: 6 additions & 6 deletions
@@ -82,7 +82,7 @@
     is_debug_mode,
     is_mx_fp,
     is_nv_fp,
-    is_optimum_habana_available,
+    is_hpex_available,
     is_standard_fp,
     is_static_wfp8afp8,
     is_wfp8afp8,
@@ -375,7 +375,7 @@ def __init__(
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

-        if is_optimum_habana_available():
+        if is_hpex_available():
             logger.info("optimum Habana is available, import htcore explicitly.")
             import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
             import habana_frameworks.torch.hpu as hthpu  # pylint: disable=E0401]
@@ -3238,7 +3238,7 @@ def _scale_loss_and_backward(self, scaler: Any, loss: torch.Tensor) -> torch.Ten
         """
         scale_loss = loss * 1000
         scale_loss.backward()
-        if is_optimum_habana_available():
+        if is_hpex_available():
             htcore.mark_step()
         return scale_loss

@@ -3255,7 +3255,7 @@ def _step(self, scaler: Any, optimizer: Any, lr_schedule: Any):
         """
         optimizer.step()
         # for hpu
-        if is_optimum_habana_available():
+        if is_hpex_available():
             htcore.mark_step()
         optimizer.zero_grad()
         lr_schedule.step()
@@ -3433,7 +3433,7 @@ def _scale_loss_and_backward(self, scaler, loss):
         loss = scaler.scale(loss)

         loss.backward()
-        if is_optimum_habana_available():
+        if is_hpex_available():
             htcore.mark_step()
         return loss

@@ -3447,5 +3447,5 @@ def _step(self, scaler, optimizer, lr_schedule):
         optimizer.step()
         optimizer.zero_grad()
         lr_schedule.step()
-        if is_optimum_habana_available():
+        if is_hpex_available():
             htcore.mark_step()
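
Note: every hunk in this file guards the same pattern: on Gaudi, htcore.mark_step() is called after backward() and after optimizer.step() to flush the accumulated lazy-mode graph to the HPU, and the commit only changes which probe enables that guard. Below is a minimal sketch of that guarded training-step shape, assuming auto_round is installed; train_step is a hypothetical helper for illustration, not code from this commit:

import torch

from auto_round.utils import is_hpex_available

if is_hpex_available():
    import habana_frameworks.torch.core as htcore  # pylint: disable=E0401


def train_step(model, optimizer, loss_fn, batch, target):
    # Hypothetical helper mirroring the _scale_loss_and_backward/_step hunks above.
    loss = loss_fn(model(batch), target)
    loss.backward()
    if is_hpex_available():
        htcore.mark_step()  # flush the lazy-mode graph after backward
    optimizer.step()
    if is_hpex_available():
        htcore.mark_step()  # and again after the optimizer step
    optimizer.zero_grad()
    return loss

On non-HPU hardware is_hpex_available() returns False and the calls are skipped, so the same step function runs everywhere.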

auto_round/utils.py

Lines changed: 24 additions & 6 deletions
@@ -115,7 +115,7 @@ def __init__(self, module_name):
         """Init LazyImport object.

         Args:
-            module_name (string): The name of module imported later
+            module_name (string): The name of module imported later
         """
         self.module_name = module_name
         self.module = None
@@ -143,13 +143,31 @@ def __call__(self, *args, **kwargs):
 auto_gptq = LazyImport("auto_gptq")
 htcore = LazyImport("habana_frameworks.torch.core")

+################ Check available sys.module to decide behavior #################
+def is_package_available(package_name):
+    """Check if the package exists in the environment without importing.
+
+    Args:
+        package_name (str): package name
+    """
+    from importlib.util import find_spec
+
+    package_spec = find_spec(package_name)
+    return package_spec is not None
+
+
+## check hpex
+if is_package_available("habana_frameworks"):
+    _hpex_available = True
+    import habana_frameworks.torch.hpex  # pylint: disable=E0401
+else:
+    _hpex_available = False
+

 @torch._dynamo.disable()
 @lru_cache(None)
-def is_optimum_habana_available():
-    from transformers.utils.import_utils import is_optimum_available
-
-    return is_optimum_available() and importlib.util.find_spec("optimum.habana") is not None
+def is_hpex_available():
+    return _hpex_available


 def get_module(module, key):
@@ -552,7 +570,7 @@ def is_valid_digit(s):
     if torch.cuda.is_available():
         device = torch.device("cuda")
         # logger.info("Using GPU device")
-    elif is_optimum_habana_available():  # pragma: no cover
+    elif is_hpex_available():  # pragma: no cover
         device = torch.device("hpu")
         # logger.info("Using HPU device")
     elif torch.xpu.is_available():  # pragma: no cover
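
Note: the new probe asks importlib.util.find_spec whether habana_frameworks can be located on sys.path, which detects the HPU stack even when optimum.habana is not installed (presumably the "hpu detect issue" this commit fixes). Below is a standalone sketch of the same detection and device-selection idea; the cpu fallback and print are illustrative, not the commit's exact code:

import torch
from importlib.util import find_spec


def is_package_available(package_name):
    # True when the package can be located on sys.path without importing it.
    return find_spec(package_name) is not None


if torch.cuda.is_available():
    device = torch.device("cuda")
elif is_package_available("habana_frameworks"):
    # Importing hpex loads the Habana plugin and registers the "hpu" backend.
    import habana_frameworks.torch.hpex  # pylint: disable=E0401
    device = torch.device("hpu")
else:
    device = torch.device("cpu")
print(device)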
