
Commit b12e87f

[platforms] enable platform plugins (#11602)

Signed-off-by: youkaichao <youkaichao@gmail.com>

1 parent 5dbf854 · commit b12e87f

File tree: 23 files changed, +360 −181 lines


.buildkite/test-pipeline.yaml
Lines changed: 19 additions & 6 deletions

@@ -106,14 +106,12 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
   - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
-  - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
   - pytest -v -s entrypoints/test_chat_utils.py
   - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@@ -333,8 +331,6 @@ steps:
   - vllm/
   - tests/models
   commands:
-  - pip install -e ./plugins/vllm_add_dummy_model
-  - pytest -v -s models/test_oot_registration.py # it needs a clean process
   - pytest -v -s models/test_registry.py
   - pytest -v -s models/test_initialization.py

@@ -469,11 +465,28 @@ steps:
   - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
-  - pip install -e ./plugins/vllm_add_dummy_model
-  - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py

+- label: Plugin Tests (2 GPUs) # 40min
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  fast_check: true
+  source_file_dependencies:
+  - vllm/plugins/
+  - tests/plugins/
+  commands:
+  # begin platform plugin tests, all the code in-between runs on dummy platform
+  - pip install -e ./plugins/vllm_add_dummy_platform
+  - pytest -v -s plugins_tests/test_platform_plugins.py
+  - pip uninstall vllm_add_dummy_platform -y
+  # end platform plugin tests
+  # other tests continue here:
+  - pip install -e ./plugins/vllm_add_dummy_model
+  - pytest -v -s distributed/test_distributed_oot.py
+  - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
+  - pytest -v -s models/test_oot_registration.py # it needs a clean process
+
 - label: Multi-step Tests (4 GPUs) # 36min
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4

docs/source/design/plugin_system.md
Lines changed: 4 additions & 2 deletions

@@ -41,9 +41,11 @@ Every plugin has three parts:
 2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
 3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.

-## What Can Plugins Do?
+## Types of supported plugins

-Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
+- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.
+
+- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.

 ## Guidelines for Writing Plugins
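For orientation, a general plugin's entry point resolves to a function like the sketch below. This is only an illustrative sketch based on the documentation text above; the architecture name `MyLlava` and the class path are hypothetical, not part of this commit.

```python
# Hypothetical register() function for a plugin in the "vllm.general_plugins"
# group (e.g. declared as "vllm_add_dummy_model:register" in setup.py).
# The architecture name and class path below are illustrative only.
from vllm import ModelRegistry


def register():
    # Map an architecture name (as it appears in a HF config's "architectures"
    # field) to an out-of-tree implementation, if not already registered.
    if "MyLlava" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model(
            "MyLlava", "vllm_add_dummy_model.my_llava:MyLlava")
```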

tests/conftest.py
Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,6 @@
                     to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.platforms import current_platform
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity)

@@ -242,6 +241,7 @@ def video_assets() -> _VideoAssets:
 class HfRunner:

     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
+        from vllm.platforms import current_platform
         if x is None or isinstance(x, (bool, )):
             return x

tests/kernels/test_attention_selector.py
Lines changed: 8 additions & 8 deletions

@@ -5,7 +5,10 @@

 from tests.kernels.utils import override_backend_env_variable
 from vllm.attention.selector import which_attn_to_use
-from vllm.platforms import cpu, cuda, openvino, rocm
+from vllm.platforms.cpu import CpuPlatform
+from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.openvino import OpenVinoPlatform
+from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


@@ -20,26 +23,23 @@ def test_env(name: str, device: str, monkeypatch):
     override_backend_env_variable(monkeypatch, name)

     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   cpu.CpuPlatform()):
+        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   rocm.RocmPlatform()):
+        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   openvino.OpenVinoPlatform()):
+                   OpenVinoPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   cuda.CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == name
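The rewritten test patches `current_platform` at the point where it is looked up (`vllm.attention.selector.current_platform`) rather than importing a platform module. Below is a minimal, self-contained sketch of that patching pattern; the stand-in module and helper are illustrative only, not vLLM code.

```python
# Sketch of patching a module-level attribute so code that reads it at call
# time sees the replacement. "selector" and pick_backend() are stand-ins.
import types
from unittest.mock import patch


class FakeCpuPlatform:
    device_name = "cpu"


# Stand-in for a module (like vllm.attention.selector) holding current_platform.
selector = types.ModuleType("selector")
selector.current_platform = None


def pick_backend() -> str:
    # Reads the module-level attribute at call time, so patching it works.
    return ("TORCH_SDPA" if isinstance(selector.current_platform, FakeCpuPlatform)
            else "FLASH_ATTN")


with patch.object(selector, "current_platform", FakeCpuPlatform()):
    assert pick_backend() == "TORCH_SDPA"
assert pick_backend() == "FLASH_ATTN"  # restored after the with block
```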
New file — setup.py for the vllm_add_dummy_platform test plugin
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+from setuptools import setup
+
+setup(
+    name='vllm_add_dummy_platform',
+    version='0.1',
+    packages=['vllm_add_dummy_platform'],
+    entry_points={
+        'vllm.platform_plugins': [
+            "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
+        ]
+    })
New file — vllm_add_dummy_platform/__init__.py (plugin entry-point function)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+from typing import Optional
+
+
+def dummy_platform_plugin() -> Optional[str]:
+    return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
New file — vllm_add_dummy_platform/dummy_platform.py (DummyPlatform)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+from vllm.platforms.cuda import CudaPlatform
+
+
+class DummyPlatform(CudaPlatform):
+    device_name = "DummyDevice"
New file — plugins_tests/test_platform_plugins.py
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+def test_platform_plugins():
+    # simulate workload by running an example
+    import runpy
+    current_file = __file__
+    import os
+    example_file = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
+        "examples", "offline_inference.py")
+    runpy.run_path(example_file)
+
+    # check if the plugin is loaded correctly
+    from vllm.platforms import _init_trace, current_platform
+    assert current_platform.device_name == "DummyDevice", (
+        f"Expected DummyDevice, got {current_platform.device_name}, "
+        "possibly because current_platform is imported before the plugin"
+        f" is loaded. The first import:\n{_init_trace}")

vllm/config.py
Lines changed: 12 additions & 3 deletions

@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import current_platform, interface
+from vllm.platforms import CpuArchEnum
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,

@@ -349,6 +349,7 @@ def __init__(self,
         self.is_hybrid = self._init_is_hybrid()
         self.has_inner_state = self._init_has_inner_state()

+        from vllm.platforms import current_platform
         if current_platform.is_neuron():
             self.override_neuron_config = override_neuron_config
         else:

@@ -589,6 +590,7 @@ def _verify_quantization(self) -> None:
             raise ValueError(
                 f"Unknown quantization method: {self.quantization}. Must "
                 f"be one of {supported_quantization}.")
+        from vllm.platforms import current_platform
         current_platform.verify_quantization(self.quantization)
         if self.quantization not in optimized_quantization_methods:
             logger.warning(

@@ -644,6 +646,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,

         # Reminder: Please update docs/source/usage/compatibility_matrix.md
         # If the feature combo become valid
+        from vllm.platforms import current_platform
         if not current_platform.is_async_output_supported(self.enforce_eager):
             logger.warning(
                 "Async output processing is not supported on the "

@@ -1012,6 +1015,7 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
+        from vllm.platforms import current_platform
         if (current_platform.is_cuda() and self.block_size is not None
                 and self.block_size > 32):
             raise ValueError("CUDA Paged Attention kernel only supports "

@@ -1279,6 +1283,7 @@ def __post_init__(self) -> None:
                     f"distributed executor backend "
                     f"'{self.distributed_executor_backend}'.")
         ray_only_devices = ["tpu", "hpu"]
+        from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
             if self.distributed_executor_backend is None:

@@ -1327,7 +1332,7 @@ def use_ray(self) -> bool:
     def _verify_args(self) -> None:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
-
+        from vllm.platforms import current_platform
         if self.distributed_executor_backend not in (
                 "ray", "mp", None) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(

@@ -1528,6 +1533,7 @@ def compute_hash(self) -> str:
     def __init__(self, device: str = "auto") -> None:
         if device == "auto":
             # Automated device type detection
+            from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
             if not self.device_type:
                 raise RuntimeError("Failed to infer device type")

@@ -2241,9 +2247,10 @@ def _get_and_verify_dtype(
     else:
         torch_dtype = config_dtype

+    from vllm.platforms import current_platform
     if (current_platform.is_cpu()
             and current_platform.get_cpu_architecture()
-            == interface.CpuArchEnum.POWERPC
+            == CpuArchEnum.POWERPC
             and (config_dtype == torch.float16
                  or config_dtype == torch.float32)):
         logger.info(

@@ -3083,6 +3090,7 @@ def _get_quantization_config(
             model_config: ModelConfig,
             load_config: LoadConfig) -> Optional[QuantizationConfig]:
         """Get the quantization config."""
+        from vllm.platforms import current_platform
         if model_config.quantization is not None:
             from vllm.model_executor.model_loader.weight_utils import (
                 get_quant_config)

@@ -3145,6 +3153,7 @@ def __post_init__(self):
             self.quant_config = VllmConfig._get_quantization_config(
                 self.model_config, self.load_config)

+        from vllm.platforms import current_platform
         if self.scheduler_config is not None and \
             self.model_config is not None and \
             self.scheduler_config.chunked_prefill_enabled and \
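Nearly every change in this file (and in the files below) follows one pattern: the module-level `from vllm.platforms import current_platform` is removed and re-imported inside the functions that need it, so the platform is resolved at call time, after any platform plugin has been loaded. A simplified illustration is shown below; `verify_block_size` is a hypothetical helper, while the import and checks mirror the ones in vllm/config.py above.

```python
# Illustration of the deferred-import pattern used throughout this commit.
# verify_block_size() is hypothetical; only the import placement matters.

# Previously: resolved once, when this module was first imported, possibly
# before platform plugins were loaded.
# from vllm.platforms import current_platform


def verify_block_size(block_size: int) -> None:
    # Now: resolved when the function runs, so a plugin-provided platform
    # (e.g. the DummyPlatform used in the plugin tests) is picked up.
    from vllm.platforms import current_platform
    if current_platform.is_cuda() and block_size > 32:
        raise ValueError("CUDA Paged Attention kernel only supports "
                         "block sizes up to 32.")
```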

vllm/distributed/parallel_state.py
Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,6 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op, supports_custom_op

 if TYPE_CHECKING:

@@ -194,6 +193,7 @@ def __init__(
         assert self.cpu_group is not None
         assert self.device_group is not None

+        from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
         else:

@@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         import ray  # Lazy import Ray
         ray.shutdown()
     gc.collect()
+    from vllm.platforms import current_platform
     if not current_platform.is_cpu():
         torch.cuda.empty_cache()

vllm/engine/arg_utils.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,6 @@
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean

@@ -1094,6 +1093,7 @@ def create_engine_config(self,
         use_sliding_window = (model_config.get_sliding_window()
                               is not None)
         use_spec_decode = self.speculative_model is not None
+        from vllm.platforms import current_platform
         if (is_gpu and not use_sliding_window and not use_spec_decode
                 and not self.enable_lora
                 and not self.enable_prompt_adapter

vllm/executor/ray_utils.py
Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,6 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase

@@ -229,6 +228,7 @@ def initialize_ray_cluster(
         the default Ray cluster address.
     """
     assert_ray_available()
+    from vllm.platforms import current_platform

     # Connect to a ray cluster.
     if current_platform.is_rocm() or current_platform.is_xpu():

vllm/model_executor/guided_decoding/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -6,7 +6,7 @@
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import CpuArchEnum

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

@@ -39,6 +39,7 @@ def maybe_backend_fallback(

     if guided_params.backend == "xgrammar":
         # xgrammar only has x86 wheels for linux, fallback to outlines
+        from vllm.platforms import current_platform
         if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
             logger.warning("xgrammar is only supported on x86 CPUs. "
                            "Falling back to use outlines instead.")

vllm/model_executor/models/registry.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,6 @@
 import torch.nn as nn

 from vllm.logger import init_logger
-from vllm.platforms import current_platform

 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,

@@ -273,6 +272,7 @@ def _try_load_model_cls(
     model_arch: str,
     model: _BaseRegisteredModel,
 ) -> Optional[Type[nn.Module]]:
+    from vllm.platforms import current_platform
     current_platform.verify_model_arch(model_arch)
     try:
         return model.load_model_cls()

vllm/model_executor/utils.py
Lines changed: 2 additions & 2 deletions

@@ -3,10 +3,9 @@

 import torch

-from vllm.platforms import current_platform
-

 def set_random_seed(seed: int) -> None:
+    from vllm.platforms import current_platform
     current_platform.seed_everything(seed)


@@ -38,6 +37,7 @@ def set_weight_attrs(
     # This sometimes causes OOM errors during model loading. To avoid this,
     # we sync the param tensor after its weight loader is called.
     # TODO(woosuk): Remove this hack once we have a better solution.
+    from vllm.platforms import current_platform
     if current_platform.is_tpu() and key == "weight_loader":
         value = _make_synced_weight_loader(value)
     setattr(weight, key, value)
