
Commit 471fe65

[TPU][V1] Implicitly adjust page size when there's SMEM OOM (#16871)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 3a0fba5 commit 471fe65

3 files changed, +34 -2 lines changed


tests/v1/tpu/test_basic.py (+5 -2)

@@ -22,6 +22,7 @@
 ]
 
 TENSOR_PARALLEL_SIZES = [1]
+MAX_NUM_REQS = [16, 1024]
 
 # TODO: Enable when CI/CD will have a multi-tpu instance
 # TENSOR_PARALLEL_SIZES = [1, 4]
@@ -32,12 +33,14 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
 def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
     model: str,
     max_tokens: int,
     tensor_parallel_size: int,
+    max_num_seqs: int,
 ) -> None:
     prompt = "The next numbers of the sequence " + ", ".join(
         str(i) for i in range(1024)) + " are:"
@@ -51,9 +54,9 @@ def test_basic(
             # Note: max_num_batched_tokens == 1024 is needed here to
             # actually test chunked prompt
             max_num_batched_tokens=1024,
-            max_model_len=8196,
+            max_model_len=8192,
             gpu_memory_utilization=0.7,
-            max_num_seqs=16,
+            max_num_seqs=max_num_seqs,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens)
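
Why the new max_num_seqs=1024 case matters: with a long max_model_len and a small KV-cache page size, the Pallas block table alone would exceed the SMEM budget described in the pallas.py comment below, so this is the configuration that should exercise the implicit adjustment. A rough back-of-the-envelope check follows; it is a standalone sketch, where the 16-token page size is an illustrative assumption (not taken from this diff), while the int32 block-table entries and the ~1 MB SMEM figure come from the backend comment below:

    # Standalone sketch: Pallas block-table size for the new test case.
    max_num_seqs = 1024                  # new test parameter
    max_model_len = 8192                 # from the test config above
    page_size = 16                       # assumed small page size (illustrative)
    smem_budget = 1024 * 1024 // 2       # half of the ~1 MB SMEM cited in pallas.py

    pages_per_req = -(-max_model_len // page_size)           # ceil div -> 512
    block_table_bytes = max_num_seqs * pages_per_req * 4     # int32 entries -> 2 MiB
    print(block_table_bytes > smem_budget)                   # True -> page size must grow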

vllm/platforms/tpu.py (+14)

@@ -97,6 +97,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
             vllm_config.model_config.dtype = torch.bfloat16
 
+        if envs.VLLM_USE_V1:
+            from vllm.v1.attention.backends.pallas import (
+                PallasAttentionBackend)
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > vllm_config.cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to make sure there's "
+                    "no SMEM OOM",
+                    vllm_config.cache_config.block_size,
+                    min_page_size,
+                )
+                vllm_config.cache_config.block_size = min_page_size
+
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
vllm/v1/attention/backends/pallas.py (+15)

@@ -10,7 +10,9 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.utils import cdiv
 
 logger = init_logger(__name__)
 
@@ -50,6 +52,19 @@ def swap_blocks(
     ) -> None:
         raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
+    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
+    # block_tables within the PallasMetadata constitute almost the entire SMEM
+    # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
+    # we simply make sure that the size is smaller than half of SMEM capacity.
+    @staticmethod
+    def get_min_page_size(vllm_config: VllmConfig) -> int:
+        max_num_page_per_req = (1024 * 1024 // 2 //
+                                vllm_config.scheduler_config.max_num_seqs // 4)
+        min_page_size = cdiv(vllm_config.model_config.max_model_len,
+                             max_num_page_per_req)
+        min_page_size = 1 << (min_page_size - 1).bit_length()
+        return min_page_size
+
 
 @dataclass
 class PallasMetadata:
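
Plugging the two test configurations into this formula (a hand re-derivation with plain integers, not code from the commit) shows why only the large-batch case changes anything: with max_model_len=8192, max_num_seqs=16 yields a minimum page size of 1, while max_num_seqs=1024 yields 64 (already a power of two), so check_and_update_config above raises block_size to 64 for that case.

    # Hand re-derivation of get_min_page_size with plain ints (no vLLM imports).
    def min_page_size(max_model_len: int, max_num_seqs: int) -> int:
        # Half of the ~1 MB SMEM, divided across max_num_seqs rows of int32 entries.
        max_pages_per_req = 1024 * 1024 // 2 // max_num_seqs // 4
        pages_needed = -(-max_model_len // max_pages_per_req)   # cdiv
        return 1 << (pages_needed - 1).bit_length()             # round up to a power of two

    print(min_page_size(8192, 16))    # 1  -> existing block size already suffices
    print(min_page_size(8192, 1024))  # 64 -> block_size is bumped to 64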
