@@ -4,6 +4,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
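
The new `vllm.envs` import exposes vLLM's environment-derived feature flags, including `VLLM_USE_V1`, which gates the V1 code paths below. As a rough sketch of the pattern (illustrative, not vLLM's actual implementation), such a module reads the flag once from the process environment:

# Illustrative sketch of the env-flag pattern assumed behind vllm.envs.
import os

# Truthy when the process is launched with VLLM_USE_V1=1.
VLLM_USE_V1: bool = os.environ.get("VLLM_USE_V1", "0") == "1"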
@@ -33,22 +34,28 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool,
                              use_mla: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        raise NotImplementedError
+        return "tpu"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1
 
     @classmethod
     def inference_mode(cls):
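
Both branches return the attention backend as a dotted class path rather than a class object, which keeps this platform module import-light; the class is imported lazily at use time. A minimal sketch of resolving such a path (illustrative helper, not necessarily vLLM's own resolver):

# Illustrative helper: resolve "pkg.module.ClassName" to the class object.
import importlib

def resolve_qualname(qualname: str):
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# Example:
# backend_cls = resolve_qualname(
#     "vllm.attention.backends.pallas.PallasAttentionBackend")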
@@ -63,22 +70,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 16
 
         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."
 
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"
 
-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
@@ -88,8 +91,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False