Commit fbd5df0

[V1] TPU support - refactored
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent d84cef7 commit fbd5df0

9 files changed: +1742 −28 lines changed

examples/offline_inference/basic.py

Lines changed: 4 additions & 3 deletions
@@ -10,15 +10,16 @@
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+# llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
 # Print the outputs.
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
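For reference, a minimal run sketch of the updated example, assuming a working vLLM TPU install; the VLLM_USE_V1 switch is taken from the test changes below, and basic.py itself does not set it.

import os
os.environ["VLLM_USE_V1"] = "1"  # assumption: opt into the V1 engine, as the tests below do

from vllm import LLM, SamplingParams

# Same model and limits as the updated example above.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
outputs = llm.generate(["The future of AI is"], SamplingParams())
for out in outputs:
    print(f"Prompt: {out.prompt!r}, Generated text: {out.outputs[0].text!r}")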

tests/entrypoints/llm/test_accuracy.py

Lines changed: 15 additions & 5 deletions
@@ -21,10 +21,13 @@
 EXPECTED_VALUE = 0.58
 
 
-def run_test():
+def run_test(more_args=None):
     """Run the end to end accuracy test."""
 
-    model_args = f"pretrained={MODEL_NAME},max_model_len=2048"
+    model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
+
+    if more_args is not None:
+        model_args = "{},{}".format(model_args, more_args)
 
     results = lm_eval.simple_evaluate(
         model="vllm",
@@ -39,14 +42,21 @@ def run_test():
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
-@pytest.mark.skipif(not current_platform.is_cuda(),
-                    reason="V1 is currently only supported on CUDA.")
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 is currently only supported on CUDA and TPU")
 def test_lm_eval_accuracy_v1_engine(monkeypatch):
     """Run with the V1 Engine."""
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        run_test()
+
+        more_args = None
+        if current_platform.is_tpu():
+            # Limit compilation time for TPU V1
+            more_args = "max_num_seqs=64"
+
+        run_test(more_args)
 
 
 def test_lm_eval_accuracy_v0_engine(monkeypatch):
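A small sketch of the argument composition the updated run_test() performs on the TPU path; the MODEL_NAME value below is a placeholder, since the real constant is defined above the hunk shown here.

MODEL_NAME = "placeholder/model"  # placeholder; the actual constant lives above the hunk
more_args = "max_num_seqs=64"  # the TPU-only limit added in test_lm_eval_accuracy_v1_engine

model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
if more_args is not None:
    model_args = "{},{}".format(model_args, more_args)

print(model_args)  # pretrained=placeholder/model,max_model_len=4096,max_num_seqs=64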

tests/entrypoints/openai/correctness/test_lmeval.py

Lines changed: 11 additions & 4 deletions
@@ -21,7 +21,7 @@
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
 MORE_ARGS_LIST = [
     [],  # Default
     ["--enable-chunked-prefill"],  # Chunked
@@ -67,14 +67,21 @@ def run_test(more_args):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
-@pytest.mark.skipif(not current_platform.is_cuda(),
-                    reason="V1 currently only supported on CUDA")
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 currently only supported on CUDA and TPU")
 def test_lm_eval_accuracy_v1_engine(monkeypatch):
     """Run with the V1 Engine."""
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        run_test([])
+        more_args = []
+
+        # Limit compilation time for V1
+        if current_platform.is_tpu():
+            more_args = ["--max-num-seqs", "64"]
+
+        run_test(more_args)
 
 
 @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ class _Backend(enum.Enum):
     TRITON_MLA = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
     BLOCK_SPARSE_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
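A minimal sketch (not vLLM code) of how the new enum member can be mapped to a distinct backend class path; the import strings are taken from the tpu.py diff below, and the enum here is a trimmed copy for illustration only.

import enum

class _Backend(enum.Enum):  # trimmed copy for illustration only
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()

BACKEND_CLS = {
    _Backend.PALLAS: "vllm.attention.backends.pallas.PallasAttentionBackend",
    _Backend.PALLAS_VLLM_V1:
    "vllm.v1.attention.backends.pallas.PallasAttentionBackend",
}

print(BACKEND_CLS[_Backend.PALLAS_VLLM_V1])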

vllm/platforms/tpu.py

Lines changed: 38 additions & 16 deletions
@@ -4,6 +4,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
@@ -33,22 +34,28 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool,
                              use_mla: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        raise NotImplementedError
+        return "tpu"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1
 
     @classmethod
     def inference_mode(cls):
@@ -63,22 +70,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 16
 
         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."
 
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"
 
-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
@@ -88,8 +91,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
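A standalone sketch (assumptions only, not the vLLM implementation) of the worker selection and V1 cache-config adjustment that check_and_update_config() performs in this diff; the functions below are hypothetical helpers that mirror the branches added above.

def resolve_tpu_worker(use_v1: bool, is_multi_step: bool) -> str:
    # Mirrors the branch added above: V1 always gets the new V1 TPUWorker.
    if use_v1:
        return "vllm.v1.worker.tpu_worker.TPUWorker"
    if is_multi_step:
        return "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
    return "vllm.worker.tpu_worker.TPUWorker"


def adjust_prefix_caching(use_v1: bool, enable_prefix_caching: bool) -> bool:
    # V1 on TPU disables prefix caching for now (see the TODO above).
    return False if use_v1 else enable_prefix_caching


assert resolve_tpu_worker(True, False) == "vllm.v1.worker.tpu_worker.TPUWorker"
assert resolve_tpu_worker(False, True).endswith("MultiStepTPUWorker")
assert adjust_prefix_caching(True, True) is False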
