From f256ebe4df6757d76f1f1642d7e110268a2f8190 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Sun, 2 Feb 2025 18:17:26 +0800
Subject: [PATCH] [Hardware][Intel GPU] add XPU bf16 support (#12392)

Signed-off-by: Kunshang Ji
---
 .../installation/gpu/xpu.inc.md |  2 +-
 vllm/platforms/xpu.py           | 23 ++++++++++++++++---
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md
index 4116826789e5c..ef02d9a078a1b 100644
--- a/docs/source/getting_started/installation/gpu/xpu.inc.md
+++ b/docs/source/getting_started/installation/gpu/xpu.inc.md
@@ -36,7 +36,7 @@ VLLM_TARGET_DEVICE=xpu python setup.py install
 
 :::{note}
 - FP16 is the default data type in the current XPU backend. The BF16 data
-  type will be supported in the future.
+  type is supported on Intel Data Center GPU; it is not yet supported on Intel Arc GPU.
 :::
 
 ## Set up using Docker
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index a5ca77f57cf47..039cdd5adc9af 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -66,9 +66,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # check and update model config
         model_config = vllm_config.model_config
         if model_config.dtype == torch.bfloat16:
-            logger.warning(
-                "bfloat16 is not fully supported on XPU, casting to float16.")
-            model_config.dtype = torch.float16
+            bf16_supported = cls.device_support_bf16()
+            if not bf16_supported:
+                logger.warning(
+                    "bfloat16 is only supported on Intel Data Center GPU; "
+                    "it is not yet supported on Intel Arc GPU. Your device "
+                    "is %s, which does not support bfloat16; falling back "
+                    "to float16.", cls.get_device_name())
+                model_config.dtype = torch.float16
         if not model_config.enforce_eager:
             logger.warning(
                 "CUDA graph is not supported on XPU, fallback to the eager "
@@ -116,3 +121,15 @@ def get_current_memory_usage(cls,
                                  ) -> float:
         torch.xpu.reset_peak_memory_stats(device)
         return torch.xpu.max_memory_allocated(device)
+
+    @classmethod
+    def device_support_bf16(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        if device_name.count("arc") > 0:
+            return False
+        elif device_name.count("data center gpu") > 0:
+            return True
+        else:
+            logger.warning("Unknown device name %s; defaulting to float16.",
+                           device_name)
+            return False
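
Note (not part of the applied patch): the minimal sketch below illustrates, outside of vLLM, the device-name heuristic that the new device_support_bf16() classmethod uses. It assumes a PyTorch build with XPU support (torch.xpu available); the helper name xpu_supports_bf16 and the call to torch.xpu.get_device_name() are illustrative stand-ins for XPUPlatform.get_device_name(), not part of this change.

    import torch

    def xpu_supports_bf16() -> bool:
        # Same heuristic as XPUPlatform.device_support_bf16() in the patch:
        # decide from the device name reported by the XPU runtime.
        name = torch.xpu.get_device_name().lower()
        if "arc" in name:              # Intel Arc (client) GPUs: no bf16 yet
            return False
        if "data center gpu" in name:  # Intel Data Center GPU: bf16 supported
            return True
        # Unknown device: be conservative and report no bf16 support,
        # mirroring the float16 fallback in check_and_update_config().
        return False

    if __name__ == "__main__":
        dtype = torch.bfloat16 if xpu_supports_bf16() else torch.float16
        print(f"Selected dtype: {dtype}")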