Skip to content

[TPU] support attention head dim smaller than 128 #19620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/v1/tpu/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,43 @@ def test_basic(
assert "1024" in output or "0, 1" in output


@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This is a basic test for TPU only")
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
max_tokens: int,
max_num_seqs: int,
) -> None:
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, by violating privacy",
" what is essential is love.",
" but in rising every time we fall.",
]
# test head dim = 96
model = "microsoft/Phi-3-mini-128k-instruct"

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

with vllm_runner(model,
max_num_batched_tokens=256,
max_num_seqs=max_num_seqs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for output, answer in zip(vllm_outputs, answers):
generated_text = output[1]
assert answer in generated_text


TP_SIZE_8 = 8


Expand Down
35 changes: 28 additions & 7 deletions vllm/v1/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

logger = init_logger(__name__)

# TPU requires the head size to be a multiple of 128.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it more of a Pallas kernel requirement?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the fundamental issue lies with the TPU hardware. We need to implement padding, either within the model or the kernel. In this case, we've opted to pad at the model level.

TPU_HEAD_SIZE_ALIGNMENT = 128


class PallasAttentionBackend(AttentionBackend):

Expand All @@ -43,6 +46,14 @@ def get_kv_cache_shape(
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
padded_head_size = cdiv(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
num_blocks = num_blocks * head_size // padded_head_size
if padded_head_size != head_size:
logger.warning_once(
"head size is padded to %d, and num_blocks is adjusted to %d"
" accordingly", padded_head_size, num_blocks)
head_size = padded_head_size
return (num_blocks, block_size, num_kv_heads * 2, head_size)

@staticmethod
Expand Down Expand Up @@ -132,8 +143,6 @@ def __init__(
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
Expand Down Expand Up @@ -187,6 +196,18 @@ def forward(
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
padded_head_size = cdiv(
self.head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
query = torch.nn.functional.pad(
query, (0, padded_head_size - self.head_size), value=0.0)
key = torch.nn.functional.pad(
key, (0, padded_head_size - self.head_size), value=0.0)
value = torch.nn.functional.pad(
value, (0, padded_head_size - self.head_size), value=0.0)

if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
Expand All @@ -213,6 +234,9 @@ def forward(
soft_cap=self.logits_soft_cap,
)

if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
output = output[:, :, :self.head_size]

return output.reshape(num_tokens, hidden_size)


Expand All @@ -231,11 +255,8 @@ def write_to_kv_cache(

"""
_, _, num_combined_kv_heads, head_size = kv_cache.shape
num_kv_heads = num_combined_kv_heads // 2

key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size)

Expand Down
Loading