Limit PaliGemma to half precision on ROCm, restore use of Triton FA for it
mawong-amd committed Jul 19, 2024
1 parent 0a19008 commit e75ec14
Showing 2 changed files with 18 additions and 11 deletions.
17 changes: 9 additions & 8 deletions tests/models/test_paligemma.py
@@ -1,4 +1,3 @@
-import os
 from typing import List, Optional, Tuple, Type
 
 import pytest
@@ -24,12 +23,6 @@

 models = ["google/paligemma-3b-mix-224"]
 
-# ROCm Triton FA can run into shared memory issues with these models,
-# use other backends in the meantime
-# FIXME (mattwong, gshtrasb, hongxiayan)
-if is_hip():
-    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
-
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                          Optional[SampleLogprobs]],
@@ -138,7 +131,15 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float", "half"])
+@pytest.mark.parametrize("dtype", [
+    pytest.param(
+        "float",
+        marks=pytest.mark.skipif(
+            is_hip(),
+            reason=
+            "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
+    ), "half"
+])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
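For reference, the new parametrization above uses pytest.param to attach a skipif mark to just the "float" case, so only that dtype is skipped on ROCm while "half" still runs. A minimal standalone sketch of the same pattern, where the on_rocm() helper is a hypothetical stand-in for vLLM's is_hip():

import pytest


def on_rocm() -> bool:
    # Hypothetical stand-in for vllm.utils.is_hip(): report whether we run on ROCm.
    return False


@pytest.mark.parametrize("dtype", [
    pytest.param("float",
                 marks=pytest.mark.skipif(
                     on_rocm(),
                     reason="32-bit precision is skipped on ROCm here")),
    "half",
])
def test_dtype_values(dtype: str) -> None:
    # "float" is skipped when on_rocm() returns True; "half" always runs.
    assert dtype in ("float", "half")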
12 changes: 9 additions & 3 deletions vllm/model_executor/models/__init__.py
@@ -91,9 +91,15 @@
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM": _ROCM_SWA_REASON,
"MistralForCausalLM": _ROCM_SWA_REASON,
"MixtralForCausalLM": _ROCM_SWA_REASON,
"Qwen2ForCausalLM":
_ROCM_SWA_REASON,
"MistralForCausalLM":
_ROCM_SWA_REASON,
"MixtralForCausalLM":
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma")
}


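The _ROCM_PARTIALLY_SUPPORTED_MODELS mapping records architectures that load on ROCm but with documented caveats; adding PaliGemmaForConditionalGeneration here lets it keep using Triton FA while flagging the 32-bit limitation. A rough sketch of how such a table could be consulted when a model architecture is resolved; check_rocm_support and the logger wiring below are illustrative, not vLLM's exact registry code:

import logging
from typing import Dict

logger = logging.getLogger(__name__)

# Illustrative copy of the table edited above; the real one lives in
# vllm/model_executor/models/__init__.py.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
}


def check_rocm_support(model_arch: str, on_rocm: bool) -> None:
    # Hypothetical helper: emit a warning when a partially supported
    # architecture is requested on ROCm, instead of refusing to load it.
    if on_rocm and model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
        logger.warning("Model architecture %s is partially supported on ROCm: %s",
                       model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])


check_rocm_support("PaliGemmaForConditionalGeneration", on_rocm=True)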
