Skip to content

Commit 8a402cf

Browse files
MengqingCao authored and mzusman committed
[Platform] Move model arch check to platform (vllm-project#11503)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
1 parent aac5a3e commit 8a402cf

File tree

3 files changed

+51
-37
lines changed

3 files changed

+51
-37
lines changed

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -187,31 +187,6 @@
187187
**_SPECULATIVE_DECODING_MODELS,
188188
}
189189

190-
# Models not supported by ROCm.
191-
_ROCM_UNSUPPORTED_MODELS: List[str] = []
192-
193-
# Models partially supported by ROCm.
194-
# Architecture -> Reason.
195-
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
196-
"Triton flash attention. For half-precision SWA support, "
197-
"please use CK flash attention by setting "
198-
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
199-
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
200-
"Qwen2ForCausalLM":
201-
_ROCM_SWA_REASON,
202-
"MistralForCausalLM":
203-
_ROCM_SWA_REASON,
204-
"MixtralForCausalLM":
205-
_ROCM_SWA_REASON,
206-
"PaliGemmaForConditionalGeneration":
207-
("ROCm flash attention does not yet "
208-
"fully support 32-bit precision on PaliGemma"),
209-
"Phi3VForCausalLM":
210-
("ROCm Triton flash attention may run into compilation errors due to "
211-
"excessive use of shared memory. If this happens, disable Triton FA "
212-
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
213-
}
214-
215190

216191
@dataclass(frozen=True)
217192
class _ModelInfo:
@@ -297,17 +272,7 @@ def _try_load_model_cls(
297272
model_arch: str,
298273
model: _BaseRegisteredModel,
299274
) -> Optional[Type[nn.Module]]:
300-
if current_platform.is_rocm():
301-
if model_arch in _ROCM_UNSUPPORTED_MODELS:
302-
raise ValueError(f"Model architecture '{model_arch}' is not "
303-
"supported by ROCm for now.")
304-
305-
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
306-
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
307-
logger.warning(
308-
"Model architecture '%s' is partially "
309-
"supported by ROCm: %s", model_arch, msg)
310-
275+
current_platform.verify_model_arch(model_arch)
311276
try:
312277
return model.load_model_cls()
313278
except Exception:

vllm/platforms/interface.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
199199
"""
200200
pass
201201

202+
@classmethod
203+
def verify_model_arch(cls, model_arch: str) -> None:
204+
"""
205+
Verify whether the current platform supports the specified model
206+
architecture.
207+
208+
- This will raise an Error or Warning based on the model support on
209+
the current platform.
210+
- By default all models are considered supported.
211+
"""
212+
pass
213+
202214
@classmethod
203215
def verify_quantization(cls, quant: str) -> None:
204216
"""

vllm/platforms/rocm.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from functools import lru_cache
3-
from typing import TYPE_CHECKING, Optional
3+
from typing import TYPE_CHECKING, Dict, List, Optional
44

55
import torch
66

@@ -33,6 +33,31 @@
3333
" `spawn` instead.")
3434
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
3535

36+
# Models not supported by ROCm.
37+
_ROCM_UNSUPPORTED_MODELS: List[str] = []
38+
39+
# Models partially supported by ROCm.
40+
# Architecture -> Reason.
41+
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
42+
"Triton flash attention. For half-precision SWA support, "
43+
"please use CK flash attention by setting "
44+
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
45+
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
46+
"Qwen2ForCausalLM":
47+
_ROCM_SWA_REASON,
48+
"MistralForCausalLM":
49+
_ROCM_SWA_REASON,
50+
"MixtralForCausalLM":
51+
_ROCM_SWA_REASON,
52+
"PaliGemmaForConditionalGeneration":
53+
("ROCm flash attention does not yet "
54+
"fully support 32-bit precision on PaliGemma"),
55+
"Phi3VForCausalLM":
56+
("ROCm Triton flash attention may run into compilation errors due to "
57+
"excessive use of shared memory. If this happens, disable Triton FA "
58+
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
59+
}
60+
3661

3762
class RocmPlatform(Platform):
3863
_enum = PlatformEnum.ROCM
@@ -102,6 +127,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
102127
else:
103128
parallel_config.worker_cls = "vllm.worker.worker.Worker"
104129

130+
@classmethod
131+
def verify_model_arch(cls, model_arch: str) -> None:
132+
if model_arch in _ROCM_UNSUPPORTED_MODELS:
133+
raise ValueError(f"Model architecture '{model_arch}' is not "
134+
"supported by ROCm for now.")
135+
136+
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
137+
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
138+
logger.warning(
139+
"Model architecture '%s' is partially "
140+
"supported by ROCm: %s", model_arch, msg)
141+
105142
@classmethod
106143
def verify_quantization(cls, quant: str) -> None:
107144
super().verify_quantization(quant)

0 commit comments

Comments (0)