diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 236efb1d2219d5..b8eb9f5a8b4222 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -22,7 +22,6 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.cuda.amp import autocast from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions @@ -219,7 +218,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with autocast(enabled=False): + with torch.amp.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 79e5a27d3c678c..7c62d06949d44a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -25,7 +25,6 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.cuda.amp import autocast from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -249,7 +248,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with autocast(enabled=False): + with torch.amp.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index bd7484cc9fc305..b972d285aed3e5 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -813,23 +813,24 @@ def require_torch_multi_npu(test_case): def require_torch_xpu(test_case): """ - Decorator marking a test that requires XPU and IPEX. + Decorator marking a test that requires XPU (in PyTorch). - These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch - version. + These tests are skipped when XPU backend is not available. XPU backend might be available either via stock + PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version + must match match current PyTorch version. """ - return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case) + return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case) def require_torch_multi_xpu(test_case): """ - Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are - skipped on a machine without IPEX or multiple XPUs. + Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without + multiple XPUs. To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu" """ if not is_torch_xpu_available(): - return unittest.skip("test requires IPEX and at least one XPU device")(test_case) + return unittest.skip("test requires PyTorch XPU")(test_case) return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 80b0740d20a3d1..98f35501928965 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -40,6 +40,7 @@ ExplicitEnum, cached_property, is_accelerate_available, + is_ipex_available, is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, @@ -2136,6 +2137,8 @@ def _setup_devices(self) -> "torch.device": if self.use_cpu: device = torch.device("cpu") elif is_torch_xpu_available(): + if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"): + raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`") device = torch.device("xpu:0") torch.xpu.set_device(device) elif is_torch_mlu_available(): diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 75a3243c9443aa..310101e0a9a29a 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -747,13 +747,18 @@ def get_major_and_minor_from_version(full_version): @lru_cache def is_torch_xpu_available(check_device=False): - "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment" - if not is_ipex_available(): + """ + Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or + via stock PyTorch (>=2.4) and potentially if a XPU is in the environment + """ + if not is_torch_available(): return False - import intel_extension_for_pytorch # noqa: F401 import torch + if is_ipex_available(): + import intel_extension_for_pytorch # noqa: F401 + if check_device: try: # Will raise a RuntimeError if no XPU is found