Description
Progress
- Implement a TPU executor that works on a single TPU chip (without tensor parallelism). See: [Hardware] Initial TPU integration #5292
- Support single-host tensor parallel inference. See: [Hardware][TPU] Implement tensor parallelism with Ray #5871
- Support multi-host tensor parallel inference. See: [TPU] Support multi-host inference #7457
- Support INT8 quantization
- Support MoE models such as Mixtral. See: [Hardware][TPU] Support MoE with Pallas GMM kernel #6457
- Benchmark and optimize the TPU backend performance
Project Scope
This project focuses on making vLLM compatible with Google Cloud TPUs. Our goal is seamless integration so users can easily run vLLM on TPUs for both online and offline inference. We will target common setups first: popular models such as Gemma, running with the bfloat16 data type.
Target TPUs and Models
We will focus on the most recent generations of TPUs, namely TPU v4, v5e, and v5p, because they outperform previous generations. We will start by making sure vLLM works with dense models such as Gemma. After that, we will expand support to Mixture-of-Experts (MoE) models such as Mixtral.
Features Not Included (for now)
The following features are outside the scope of this initial project, but we'd like to tackle them in the future:
- Speculative decoding
- GPTQ/AWQ quantization
- Multi-LoRA serving
Design
Overview

To integrate the TPU backend into vLLM, we will add a new TPU executor and TPU worker, which are the counterparts of the GPU executor and GPU worker, respectively. Unlike NVIDIA and AMD GPUs, which share the same executor and worker, TPUs get a separate code path because they differ significantly from GPUs. On the other hand, the two backends will share the other components of LLMEngine, namely the scheduler, KV cache manager, and tokenizer, as they are (almost) device agnostic.
PyTorch XLA and JAX
As many components of vLLM are device and runtime agnostic, it is possible to use JAX for TPU integration. However, for faster initial integration and maximum code reuse, we will start with PyTorch XLA. Adding a JAX backend to vLLM would be interesting future work.
TPU Workers

For tensor-parallel inference, the vLLM TPU executor will spin up multiple TPU workers, one per TPU chip. Specifically, we will use Ray to connect and manage the TPU workers, which may reside in different TPU VMs. Note that we will support multi-host inference within the same TPU pod slice, but we do not plan to support multi-slice inference at the moment.
Like the GPU executor, the TPU executor will use Megatron-style model partitioning for tensor-parallel inference. The partitioning strategy will be hardcoded into the model by replacing nn.Linear with RowParallelLinear and ColumnParallelLinear. Auto-sharding the model is left as future work.
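To make the partitioning concrete, here is a minimal pure-Python sketch of Megatron-style sharding for two stacked linear layers across two simulated workers. It is not vLLM's RowParallelLinear or ColumnParallelLinear implementation: the helper names (matmul, split_columns, split_rows, all_reduce_sum, parallel_mlp) are hypothetical, plain lists stand in for tensors, and the all-reduce is simulated by an element-wise sum. The activation function between the layers is omitted for brevity; an element-wise activation would be applied to each worker's local hidden slice.

```python
# Illustrative sketch only: plain Python lists stand in for tensors, and the
# "workers" are loop iterations rather than TPU chips. All names are hypothetical.

def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_columns(w, parts):
    """Column-parallel sharding: each worker holds a slice of the output columns."""
    step = len(w[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]


def split_rows(w, parts):
    """Row-parallel sharding: each worker holds a slice of the input rows."""
    step = len(w) // parts
    return [w[i * step:(i + 1) * step] for i in range(parts)]


def all_reduce_sum(partials):
    """Simulated all-reduce: element-wise sum of the per-worker partial outputs."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def parallel_mlp(x, w1, w2, parts):
    """Column-parallel first layer, row-parallel second layer, one all-reduce."""
    partials = []
    for w1_shard, w2_shard in zip(split_columns(w1, parts), split_rows(w2, parts)):
        hidden = matmul(x, w1_shard)               # local slice of the hidden activations
        partials.append(matmul(hidden, w2_shard))  # partial sum of the final output
    return all_reduce_sum(partials)


x = [[1, 2], [3, 4]]                    # [tokens, hidden]
w1 = [[1, 0, 2, 1], [0, 1, 1, 3]]       # [hidden, intermediate]
w2 = [[1, 2], [0, 1], [2, 0], [1, 1]]   # [intermediate, hidden]

reference = matmul(matmul(x, w1), w2)   # unsharded result
sharded = parallel_mlp(x, w1, w2, parts=2)
```

The sharded and unsharded results match, and only one communication step (the all-reduce after the second layer) is needed per pair of layers, which is the main appeal of pairing a column-parallel layer with a row-parallel one.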
GPU Executor vs. TPU Executor

For GPUs, vLLM uses both eager mode and CUDA graphs for model execution. Specifically, vLLM uses eager mode for prefills and CUDA graphs for decodes. vLLM currently does not use torch.compile for GPUs, but plans to use it in the future. For TPUs, on the other hand, vLLM will use torch.compile (with the openxla_eval backend) to trace the PyTorch model and lower it into an XLA graph.
While vLLM’s GPU and TPU backends will take separate code paths, they will share the PyTorch model code. Most of the custom ops written for GPUs will not be needed for TPUs, since the XLA compiler can generate equivalent code automatically. Therefore, for each target op, vLLM will have two implementations, _forward and _forward_cuda, and select one of them at run time depending on the hardware backend. For example, we can define the target ops/layers as follows:
class Op(nn.Module):

    def _forward(self, ...):
        # PyTorch implementation that can be optimized by compilers
        # such as XLA or torch.compile.
        ...

    def _forward_cuda(self, ...):
        # Implementation using custom ops written in CUDA.
        ...

    def forward(self, ...):
        if ...:
            return self._forward_cuda(...)
        else:
            return self._forward(...)
Important exceptions to this are the FlashAttention and PagedAttention custom ops, which cannot be generated by the XLA compiler. We will use custom Pallas kernels for them.
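The dispatch pattern for ordinary ops can be sketched in runnable form as follows. This is only an illustration in plain Python, without nn.Module: the class name SiluOpSketch, the use_custom_kernel flag, and the last_path attribute are hypothetical, and the selection criterion is an assumption, since the proposal leaves the condition unspecified. The "custom kernel" path here simply reuses the reference math as a placeholder.

```python
# Illustrative sketch only: all names are hypothetical, and the real selection
# criterion (elided as `if ...:` in the pseudocode above) is an assumption here.
import math


class SiluOpSketch:
    def __init__(self, use_custom_kernel: bool):
        # Hypothetical flag standing in for "is this the CUDA backend?".
        self.use_custom_kernel = use_custom_kernel
        self.last_path = None  # records which implementation ran, for inspection

    def _forward(self, xs):
        # Reference implementation that a compiler such as XLA could optimize.
        self.last_path = "reference"
        return [x / (1.0 + math.exp(-x)) for x in xs]

    def _forward_cuda(self, xs):
        # Placeholder for a call into a hand-written CUDA op; it reuses the
        # reference math so the sketch stays runnable on any machine.
        self.last_path = "custom"
        return [x / (1.0 + math.exp(-x)) for x in xs]

    def forward(self, xs):
        if self.use_custom_kernel:
            return self._forward_cuda(xs)
        return self._forward(xs)


tpu_op = SiluOpSketch(use_custom_kernel=False)
gpu_op = SiluOpSketch(use_custom_kernel=True)
tpu_out = tpu_op.forward([0.0, 1.0])
gpu_out = gpu_op.forward([0.0, 1.0])
```

Because both paths must produce the same result, the model code that calls forward does not need to know which backend is active.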
Handling Dynamic Shapes
vLLM’s continuous batching has two phases: prefill and decode. vLLM dynamically switches between the two phases based on its scheduling decisions. The input tensor shape for prefills is [batch_size, prefill_len, hidden_size], while the input tensor shape for decodes is [batch_size, 1, hidden_size], since LLMs decode tokens one by one (here we do not consider special cases such as speculative decoding). In LLM inference, batch_size and prefill_len can vary at every step.
To meet XLA’s static shape requirement, we will bucketize the possible input shapes. For decodes, we will bucketize the batch_size dimension by creating buckets for batch_size=[8, 16, 24, 32, 40, …, 256]. For prefills, to reduce the number of compiled graphs, we will fix batch_size to 1 and bucketize the prefill_len dimension by creating buckets for prefill_len=[8, 16, 32, 64, 128, …, max_model_len]. Given that each prefill input contains enough tokens to utilize TPUs efficiently, fixing batch_size to 1 is not expected to hurt performance much. The specific bucket sizes will be tuned after benchmarking the compilation overhead and end-to-end performance.
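A minimal sketch of this bucketing scheme, assuming decode buckets are multiples of 8 up to 256 and prefill buckets are powers of two starting at 8 and capped at max_model_len. The function names, the pad_id value, and the default max_model_len of 2048 are hypothetical choices for illustration, not vLLM's actual implementation.

```python
# Illustrative sketch only: function names and default values are hypothetical.
import bisect


def decode_batch_buckets(step=8, max_batch_size=256):
    """Decode buckets: 8, 16, 24, ..., max_batch_size."""
    return list(range(step, max_batch_size + 1, step))


def prefill_len_buckets(min_len=8, max_model_len=2048):
    """Prefill buckets: powers of two from min_len, ending at max_model_len."""
    buckets = []
    size = min_len
    while size < max_model_len:
        buckets.append(size)
        size *= 2
    buckets.append(max_model_len)
    return buckets


def pick_bucket(value, buckets):
    """Return the smallest bucket that is >= value (buckets must be sorted)."""
    idx = bisect.bisect_left(buckets, value)
    if idx == len(buckets):
        raise ValueError(f"{value} exceeds the largest bucket {buckets[-1]}")
    return buckets[idx]


def pad_to_bucket(token_ids, buckets, pad_id=0):
    """Pad a token list up to its bucket size so the compiled shape is reused."""
    target = pick_bucket(len(token_ids), buckets)
    return token_ids + [pad_id] * (target - len(token_ids))


decode_buckets = decode_batch_buckets()
prefill_buckets = prefill_len_buckets()
padded = pad_to_bucket([1, 2, 3], prefill_buckets)
```

With these defaults there are 32 decode shapes and 9 prefill shapes to compile, so a decode batch of 17 sequences runs in the 24-sequence graph and a 100-token prompt runs in the 128-token graph. Coarser buckets mean fewer compilations but more padding, which is the trade-off the benchmarking is meant to settle.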
References
- PyTorch XLA Llama V1 inference blog post: https://pytorch.org/blog/path-achieve-low-inference-latency/
- PyTorch XLA Llama V2 inference blog post: https://pytorch.org/blog/high-performance-llama-2/
- PyTorch XLA Llama inference example: https://github.com/pytorch-tpu/llama/tree/llama2-google-next-inference