Skip to content

[V1][TPU] Remove unnecessary padding for running on TPU. #14467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from vllm.attention.backends.utils import CommonAttentionState

# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 256
NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128


class PallasAttentionBackend(AttentionBackend):
Expand Down
20 changes: 4 additions & 16 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
NUM_QUERIES_PER_BLOCK,
PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
Expand Down Expand Up @@ -77,10 +75,8 @@ def __init__(
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = _get_padded_number(
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
NUM_QUERIES_PER_BLOCK)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
Comment on lines -80 to +79
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rounding these up was intentional since the user could specify odd values non divisible by our constraints. I think we should keep this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these constrains from kernel? If yes, we no longer need these paddings because the new kernel removed all the constrains. We can save tons of memory and computing by removing these paddings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!

@mgoin , by "user could specify odd values", what "value" are you referring to?

Here both NUM_KV_PAGES_PER_BLOCK, NUM_QUERIES_PER_BLOCK are tunable parameter of the ragged kernel. The padding is needed mainly because the kernel v1 has the constraint that self.max_num_tokens%NUM_QUERIES_PER_BLOCK==0 and self.max_num_blocks_per_req%NUM_KV_PAGES_PER_BLOCK==0.

Early this week we switched the kernel from v1 to v2 where in v2 we don't have such constraints, that's why I think we can remove these constraints.

Also note that here are the "max" num_tokens instead of the actual num_tokens we would use. For the actual num_tokens in the real workload and "warmup", we still pad to the next power of 2: https://github.com/vllm-project/vllm/blob/8ed5421aaa7da24051acdae53c860e6ce6598403/vllm/v1/worker/tpu_model_runner.py#L420C45-L420C66.


# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
Expand Down Expand Up @@ -141,16 +137,8 @@ def __init__(
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()

# self.input_batch.block_table has a shape of [max_num_reqs,
# max_num_blocks_per_req]. To reduce the number of recompilation,
# we want the block_table.shape[0] to be num_tokens.
# To make the block_table to be compatible with the paged attention
# kernel, we want the block_table[1] to be multiple of
# NUM_KV_PAGES_PER_BLOCK.
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, padded_max_num_blocks_per_req),
(self.max_num_tokens, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")

Expand Down