diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py
index e11784558f196..2d4899469b775 100644
--- a/tests/models/test_paligemma.py
+++ b/tests/models/test_paligemma.py
@@ -1,4 +1,3 @@
-import os
 from typing import List, Optional, Tuple, Type
 
 import pytest
@@ -24,12 +23,6 @@
 
 models = ["google/paligemma-3b-mix-224"]
 
-# ROCm Triton FA can run into shared memory issues with these models,
-# use other backends in the meantime
-# FIXME (mattwong, gshtrasb, hongxiayan)
-if is_hip():
-    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
-
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                          Optional[SampleLogprobs]],
@@ -138,7 +131,15 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float", "half"])
+@pytest.mark.parametrize("dtype", [
+    pytest.param(
+        "float",
+        marks=pytest.mark.skipif(
+            is_hip(),
+            reason=
+            "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
+    ), "half"
+])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 58b6d73e5f113..8e7f470240ba8 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -91,9 +91,15 @@
                     "please use CK flash attention by setting "
                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
 _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
-    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
-    "MistralForCausalLM": _ROCM_SWA_REASON,
-    "MixtralForCausalLM": _ROCM_SWA_REASON,
+    "Qwen2ForCausalLM":
+    _ROCM_SWA_REASON,
+    "MistralForCausalLM":
+    _ROCM_SWA_REASON,
+    "MixtralForCausalLM":
+    _ROCM_SWA_REASON,
+    "PaliGemmaForConditionalGeneration":
+    ("ROCm flash attention does not yet "
+     "fully support 32-bit precision on PaliGemma")
 }
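
For context on the second hunk: the patch does not show how `_ROCM_PARTIALLY_SUPPORTED_MODELS` is consumed, so the sketch below is an assumption about typical usage, not vLLM's actual registry code. It illustrates the warn-but-allow behavior such an entry enables: when a listed architecture is resolved on a ROCm platform, the stored reason is logged as a warning instead of blocking the load. The helper name `warn_if_partially_supported` and the `on_rocm` flag are hypothetical.

# Hedged sketch, not vLLM's actual implementation: shows how a partial-support
# table entry can be surfaced as a warning rather than a hard error.
import logging
from typing import Dict

logger = logging.getLogger(__name__)

_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
}


def warn_if_partially_supported(arch: str, on_rocm: bool) -> None:
    # Hypothetical helper: log the stored reason when a partially supported
    # architecture is loaded on ROCm; fully supported loads stay silent.
    if on_rocm and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
        logger.warning("Model architecture %s is partially supported on ROCm: %s",
                       arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[arch])


# Example: emits the PaliGemma warning only when the ROCm condition holds.
warn_if_partially_supported("PaliGemmaForConditionalGeneration", on_rocm=True)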