diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index bd7484cc9fc305..b972d285aed3e5 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -813,23 +813,24 @@ def require_torch_multi_npu(test_case):
 
 def require_torch_xpu(test_case):
     """
-    Decorator marking a test that requires XPU and IPEX.
+    Decorator marking a test that requires XPU (in PyTorch).
 
-    These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
-    version.
+    These tests are skipped when the XPU backend is not available. The XPU backend might be available either via
+    stock PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version
+    must match the current PyTorch version.
     """
-    return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
+    return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)
 
 
 def require_torch_multi_xpu(test_case):
     """
-    Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are
-    skipped on a machine without IPEX or multiple XPUs.
+    Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine
+    without multiple XPUs.
 
     To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
     """
     if not is_torch_xpu_available():
-        return unittest.skip("test requires IPEX and at least one XPU device")(test_case)
+        return unittest.skip("test requires PyTorch XPU")(test_case)
 
     return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
 
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 80b0740d20a3d1..98f35501928965 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -40,6 +40,7 @@
     ExplicitEnum,
     cached_property,
     is_accelerate_available,
+    is_ipex_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -2136,6 +2137,8 @@ def _setup_devices(self) -> "torch.device":
             if self.use_cpu:
                 device = torch.device("cpu")
             elif is_torch_xpu_available():
+                if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
+                    raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
                 device = torch.device("xpu:0")
                 torch.xpu.set_device(device)
             elif is_torch_mlu_available():
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 75a3243c9443aa..310101e0a9a29a 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -747,13 +747,18 @@ def get_major_and_minor_from_version(full_version):
 
 @lru_cache
 def is_torch_xpu_available(check_device=False):
-    "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
-    if not is_ipex_available():
+    """
+    Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or
+    via stock PyTorch (>=2.4) and potentially if an XPU is in the environment
+    """
+    if not is_torch_available():
         return False
 
-    import intel_extension_for_pytorch  # noqa: F401
     import torch
 
+    if is_ipex_available():
+        import intel_extension_for_pytorch  # noqa: F401
+
     if check_device:
         try:
             # Will raise a RuntimeError if no XPU is found
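
Note (illustration, not part of the patch): the sketch below mirrors, as a standalone script, the detection flow this diff introduces in is_torch_xpu_available() — try IPEX if it is installed, otherwise rely on the torch.xpu module that stock PyTorch >= 2.4 ships. The helper name xpu_is_usable is hypothetical; only plain PyTorch APIs are used.

# Hypothetical standalone sketch of the XPU detection flow from this diff.
# Assumes only that PyTorch is installed; IPEX is optional.
import importlib.util

import torch


def xpu_is_usable(check_device: bool = False) -> bool:
    # On older torch builds the XPU backend is only registered after importing IPEX.
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        import intel_extension_for_pytorch  # noqa: F401

    # Stock torch >= 2.4 exposes torch.xpu natively; without it there is no XPU backend.
    if not hasattr(torch, "xpu"):
        return False

    if check_device:
        try:
            # The XPU runtime may raise RuntimeError when no device is present.
            if torch.xpu.device_count() == 0:
                return False
        except RuntimeError:
            return False

    return torch.xpu.is_available()


if __name__ == "__main__":
    device = torch.device("xpu:0") if xpu_is_usable(check_device=True) else torch.device("cpu")
    print(f"selected device: {device}")

On the test side nothing changes for callers: tests keep using @require_torch_xpu and @require_torch_multi_xpu; only the skip condition and the skip messages are updated.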