Skip to content

Commit

Permalink
Flex xpu bug fix (#26135)
Browse files Browse the repository at this point in the history
flex gpu bug fix
  • Loading branch information
abhilash1910 authored Sep 13, 2023
1 parent 9709ab1 commit 05de038
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,12 +1425,13 @@ def __post_init__(self):
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) != "GPU")
and (self.fp16 or self.fp16_full_eval)
):
raise ValueError(
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
" (`--fp16_full_eval`) can only be used on CUDA or NPU devices."
" (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
)

if (
Expand Down

0 comments on commit 05de038

Please sign in to comment.