
Commit 7d47acc

Author: reidliu41

Commit message: fix import

Signed-off-by: reidliu41 <reid201711@gmail.com>

1 parent: ff38f0a

File tree

5 files changed: +62 -37 lines changed


vllm/__init__.py (+35 -15)

@@ -8,21 +8,6 @@

 import torch

-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.llm_engine import LLMEngine
-from vllm.entrypoints.llm import LLM
-from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
-from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
-                          CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, PoolingOutput,
-                          PoolingRequestOutput, RequestOutput, ScoringOutput,
-                          ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
-
 # set some common config/environment variables that should be set
 # for all processes created by vllm and all processes
 # that interact with vllm workers.
@@ -36,6 +21,41 @@
 # see https://github.com/vllm-project/vllm/issues/10619
 torch._inductor.config.compile_threads = 1

+_lazy_imports_module_list = {
+    "LLM": "vllm.entrypoints.llm.LLM",
+    "ModelRegistry": "vllm.model_executor.models.ModelRegistry",
+    "PromptType": "vllm.inputs.PromptType",
+    "TextPrompt": "vllm.inputs.TextPrompt",
+    "TokensPrompt": "vllm.inputs.TokensPrompt",
+    "SamplingParams": "vllm.sampling_params.SamplingParams",
+    "RequestOutput": "vllm.outputs.RequestOutput",
+    "CompletionOutput": "vllm.outputs.CompletionOutput",
+    "PoolingOutput": "vllm.outputs.PoolingOutput",
+    "PoolingRequestOutput": "vllm.outputs.PoolingRequestOutput",
+    "EmbeddingOutput": "vllm.outputs.EmbeddingOutput",
+    "EmbeddingRequestOutput": "vllm.outputs.EmbeddingRequestOutput",
+    "ClassificationOutput": "vllm.outputs.ClassificationOutput",
+    "ClassificationRequestOutput": "vllm.outputs.ClassificationRequestOutput",
+    "ScoringOutput": "vllm.outputs.ScoringOutput",
+    "ScoringRequestOutput": "vllm.outputs.ScoringRequestOutput",
+    "LLMEngine": "vllm.engine.llm_engine.LLMEngine",
+    "EngineArgs": "vllm.engine.arg_utils.EngineArgs",
+    "AsyncLLMEngine": "vllm.engine.async_llm_engine.AsyncLLMEngine",
+    "AsyncEngineArgs": "vllm.engine.arg_utils.AsyncEngineArgs",
+    "initialize_ray_cluster": "vllm.executor.ray_utils.initialize_ray_cluster",
+    "PoolingParams": "vllm.pooling_params.PoolingParams",
+}
+
+
+def __getattr__(name: str):
+    if name in _lazy_imports_module_list:
+        import importlib
+        module_path, attr = _lazy_imports_module_list[name].rsplit(".", 1)
+        mod = importlib.import_module(module_path)
+        return getattr(mod, attr)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
 __all__ = [
     "__version__",
     "__version_tuple__",

vllm/engine/arg_utils.py (+5 -2)

@@ -1834,8 +1834,11 @@ def add_cli_args(parser: FlexibleArgumentParser,
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
-        from vllm.platforms import current_platform
-        current_platform.pre_register_and_update(parser)
+        # Skip to avoid triggering platform detection
+        import sys
+        if not any(arg in sys.argv for arg in ["-h", "--help"]):
+            from vllm.platforms import current_platform
+            current_platform.pre_register_and_update(parser)
         return parser
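
Editor's note: this hunk keeps `--help` cheap. When `-h` or `--help` appears anywhere in sys.argv, the vllm.platforms import (and the hardware detection behind it) is skipped, since argparse will print usage and exit before the extra platform arguments matter. A rough standalone sketch of the same guard, with a hypothetical expensive_platform_setup standing in for current_platform.pre_register_and_update:

# help_guard.py - sketch of skipping expensive parser setup for --help;
# expensive_platform_setup is a hypothetical placeholder, not vLLM code.
import argparse
import sys


def expensive_platform_setup(parser: argparse.ArgumentParser) -> None:
    # Stand-in for work that probes hardware or imports heavy modules.
    parser.add_argument("--device", default="auto")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-log-requests", action="store_true")
    # Skip the heavy step when the user only asked for help text.
    if not any(arg in sys.argv for arg in ("-h", "--help")):
        expensive_platform_setup(parser)
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args())

The visible trade-off is that `--help` output can omit platform-specific options; the commit appears to accept that in exchange for help working on machines without a usable accelerator stack.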

vllm/executor/ray_utils.py (+1 -2)

@@ -11,7 +11,6 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -109,7 +108,7 @@ def setup_device_if_necessary(self):
         # We can remove this API after it is fixed in compiled graph.
         assert self.worker is not None, "Worker is not initialized"
         if not self.compiled_dag_cuda_device_set:
-            if current_platform.is_tpu():
+            if vllm.platforms.current_platform.is_tpu():
                 # Not needed
                 pass
             else:
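
Editor's note: here the eager `from vllm.platforms import current_platform` is dropped and the attribute is read through the vllm.platforms module path at the moment setup_device_if_necessary runs on a worker. Besides deferring the import, an attribute read at call time also observes whatever the module holds then, whereas a from-import freezes the value bound when the file was first loaded. A small self-contained illustration of that difference, using a throwaway module (all names here are illustrative, not vLLM's):

# late_binding_demo.py - why `module.attr` read at the call site can differ
# from a `from module import attr` done at import time.
import sys
import types

cfg = types.ModuleType("cfg")   # stand-in for a module like vllm.platforms
cfg.current = "cpu"
sys.modules["cfg"] = cfg

from cfg import current  # noqa: E402  (frozen copy: "cpu")


def frozen() -> str:
    return current              # still "cpu" even after cfg.current changes


def live() -> str:
    import cfg                  # cheap: module is already in sys.modules
    return cfg.current          # reads whatever cfg.current is right now


cfg.current = "tpu"
assert frozen() == "cpu"
assert live() == "tpu"
print(frozen(), live())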

vllm/model_executor/layers/spec_decode_base_sampler.py (+1 -2)

@@ -7,8 +7,6 @@
 import torch.jit
 import torch.nn as nn

-from vllm.platforms import current_platform
-

 class SpecDecodeBaseSampler(nn.Module):
     """Base class for samplers used for Speculative Decoding verification
@@ -37,6 +35,7 @@ def __init__(self, strict_mode: bool = False):
     def init_gpu_tensors(self, device: Union[int, str]) -> None:
         assert self.num_accepted_tokens is None
         if isinstance(device, int):
+            from vllm.platforms import current_platform
             device = f"{current_platform.device_type}:{device}"
         elif not isinstance(device, str):
             raise ValueError(f"Device must be int or str, get {type(device)}")
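
Editor's note: this file uses the plainest form of the deferral, a function-local import. current_platform is only needed to build a device string, so the import moves into init_gpu_tensors and importing the module itself no longer touches vllm.platforms. A minimal sketch of the same move; the class is a placeholder and platform.machine() from the stdlib merely stands in for current_platform.device_type:

# local_import.py - sketch of pushing an import into the only method that
# needs it; names here are placeholders, not vLLM's.
from typing import Union


class BaseSampler:
    def init_gpu_tensors(self, device: Union[int, str]) -> None:
        if isinstance(device, int):
            # Imported here so that importing this module stays cheap.
            import platform
            device = f"{platform.machine()}:{device}"
        elif not isinstance(device, str):
            raise ValueError(f"Device must be int or str, get {type(device)}")
        print("using device", device)

After the first call, a function-local import only costs a sys.modules lookup, which is negligible next to the tensor setup this method performs.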

vllm/plugins/__init__.py (+20 -16)

@@ -57,24 +57,28 @@ def load_general_plugins():
         return
     plugins_loaded = True

-    # some platform-specific configurations
-    from vllm.platforms import current_platform
+    # Skip to avoid triggering platform detection
+    import sys
+    if not any(arg in sys.argv for arg in ["-h", "--help"]):
+
+        # some platform-specific configurations
+        from vllm.platforms import current_platform

-    if current_platform.is_xpu():
-        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
-        torch._dynamo.config.disable = True
-    elif current_platform.is_hpu():
-        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
-        # does not support torch.compile
-        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
-        # torch.compile support
-        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
-        if is_lazy:
+        if current_platform.is_xpu():
+            # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
             torch._dynamo.config.disable = True
-            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
-            # requires enabling lazy collectives
-            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
-            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+        elif current_platform.is_hpu():
+            # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
+            # does not support torch.compile
+            # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
+            # torch.compile support
+            is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
+            if is_lazy:
+                torch._dynamo.config.disable = True
+                # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
+                # requires enabling lazy collectives
+                # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
+                os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'

     plugins = load_plugins_by_group(group='vllm.general_plugins')
     # general plugins, we only need to execute the loaded functions
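
Editor's note: taken together, the five hunks mean that importing vllm, building the CLI parser for `--help`, or loading general plugins should no longer force platform detection on its own. A rough way to spot-check that intent on an environment where vLLM is installed (the module-name prefixes below are assumptions about the package layout, not guaranteed by the commit):

# check_lazy.py - rough spot check of the intended effect; assumes vLLM is
# installed and that platform code lives under the vllm.platforms prefix.
import sys

import vllm  # noqa: F401

print("platform modules after `import vllm`:",
      [m for m in sys.modules if m.startswith("vllm.platforms")] or "none")

_ = vllm.SamplingParams  # first attribute access triggers the lazy import
print("modules pulled in by the attribute access:",
      [m for m in sys.modules if m.startswith("vllm.sampling_params")]
      or "none")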
