[RFC] Initial Support for Cloud TPUs #3620

Closed
@WoosukKwon

Project Scope

This project focuses on making vLLM compatible with Google Cloud TPUs. Our goal is seamless integration, so that users can easily run vLLM on TPUs for both online and offline inference. We will target common setups, such as popular models like Gemma running in the bfloat16 data type.

Target TPUs and Models

We will focus on the most recent TPU generations, namely TPU v4, v5e, and v5p, given their superior performance over previous generations. We will start by making sure vLLM works with dense models such as Gemma, and then expand support to Mixture-of-Experts (MoE) models such as Mixtral.

Features Not Included (for now)

The following features are outside the scope of this initial project, but we'd like to tackle them in the future:

  • Speculative decoding
  • GPTQ/AWQ quantization
  • Multi-LoRA serving

Design

Overview

To integrate the TPU backend into vLLM, we will add a new TPU executor and TPU worker, the counterparts of the GPU executor and GPU worker, respectively. Unlike NVIDIA and AMD GPUs, which share the same executor and worker, TPUs will get a separate code path because of the significant differences between GPUs and TPUs. However, the two backends will share the other components of LLMEngine, namely the scheduler, KV cache manager, and tokenizer, since these are (almost) device-agnostic.
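To make the split concrete, here is a minimal sketch in which only the executor is chosen per device while the scheduler and KV cache manager are shared. All class and function names (Scheduler, KVCacheManager, GPUExecutor, TPUExecutor, select_executor, LLMEngineSketch) are hypothetical stand-ins, not vLLM's actual classes, and the tokenizer is omitted for brevity.

```python
from dataclasses import dataclass, field


@dataclass
class Scheduler:
    """Hypothetical stand-in for the shared, device-agnostic scheduler."""


@dataclass
class KVCacheManager:
    """Hypothetical stand-in for the shared, device-agnostic KV cache manager."""


class GPUExecutor:
    """Hypothetical GPU executor, shared by NVIDIA (CUDA) and AMD (ROCm) GPUs."""


class TPUExecutor:
    """Hypothetical TPU executor with its own code path."""


def select_executor(device_type: str) -> type:
    # Only this choice depends on the hardware backend.
    if device_type in ("cuda", "rocm"):
        return GPUExecutor
    if device_type == "tpu":
        return TPUExecutor
    raise ValueError(f"unsupported device type: {device_type}")


@dataclass
class LLMEngineSketch:
    device_type: str
    scheduler: Scheduler = field(default_factory=Scheduler)
    kv_cache_manager: KVCacheManager = field(default_factory=KVCacheManager)

    def __post_init__(self) -> None:
        # The shared components above are the same for every backend;
        # the executor (and the workers it spawns) is device specific.
        self.executor = select_executor(self.device_type)()
```

With this layout, LLMEngineSketch("tpu") and LLMEngineSketch("rocm") differ only in their executor, which mirrors the "separate executor, shared everything else" design.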

PyTorch XLA and JAX

As many components of vLLM are device- and runtime-agnostic, it is possible to use JAX for TPU integration. However, for faster initial integration and maximum code reuse, we will start with PyTorch XLA. Adding a JAX backend to vLLM is interesting future work.

TPU Workers

For tensor-parallel inference, the vLLM TPU executor will spin up multiple TPU workers, one per TPU chip. Specifically, we will use Ray to connect to and manage the TPU workers, which may reside on different TPU VMs. Note that we do not plan to support multi-slice inference for now, but we will support multi-host inference within a single TPU pod slice.
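The one-worker-per-chip layout across hosts can be sketched as a simple rank assignment. This is a minimal sketch, assuming a hypothetical place_tpu_workers helper and a host-major rank order; in the design above each entry would correspond to a Ray-managed TPU worker on the given TPU VM, but Ray itself is not used here. All hosts are assumed to belong to one pod slice, since multi-slice inference is out of scope.

```python
from typing import List, NamedTuple


class WorkerPlacement(NamedTuple):
    rank: int        # tensor-parallel rank
    host: int        # index of the TPU VM (host) within the pod slice
    local_chip: int  # chip index within that host


def place_tpu_workers(num_hosts: int, chips_per_host: int) -> List[WorkerPlacement]:
    """One worker per chip, ranks laid out host by host (hypothetical helper)."""
    return [
        WorkerPlacement(host * chips_per_host + chip, host, chip)
        for host in range(num_hosts)
        for chip in range(chips_per_host)
    ]
```

For example, with 2 hosts of 4 chips each (illustrative numbers, not a statement about any particular TPU type), the tensor-parallel size is 8 and rank 5 lands on host 1, chip 1.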

Like the GPU executor, the TPU executor will use Megatron-style model partitioning for tensor-parallel inference. The partitioning strategy will be hardcoded into the model by replacing nn.Linear with RowParallelLinear and ColumnParallelLinear. Auto-sharding the model is left for future work.
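The Megatron-style split can be checked numerically. The following is a pure-Python sketch with made-up helper names (not vLLM's ColumnParallelLinear or RowParallelLinear), assuming the tensor-parallel size divides the sharded dimension evenly: a column-parallel layer shards the weight along its output dimension, a row-parallel layer shards it along its input dimension, and summing the row-parallel partial results stands in for the all-reduce across chips.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w: Matrix, n: int) -> List[Matrix]:
    """Split w (in_features x out_features) into n shards along out_features."""
    step = len(w[0]) // n
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(n)]


def split_rows(w: Matrix, n: int) -> List[Matrix]:
    """Split w into n shards along in_features."""
    step = len(w) // n
    return [w[r * step:(r + 1) * step] for r in range(n)]


def column_parallel(x: Matrix, w: Matrix, n: int) -> List[Matrix]:
    # "Chip" r computes x @ W[:, shard_r]; the outputs stay sharded.
    return [matmul(x, shard) for shard in split_columns(w, n)]


def row_parallel(x_shards: List[Matrix], w: Matrix, n: int) -> Matrix:
    # "Chip" r computes x_r @ W[shard_r, :]; adding the partial results
    # plays the role of the all-reduce in a real deployment.
    partials = [matmul(x_r, w_r) for x_r, w_r in zip(x_shards, split_rows(w, n))]
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]
```

Feeding the column-parallel output shards straight into the row-parallel layer, as in a two-layer MLP, reproduces the unsharded result with a single reduction at the end; an element-wise activation between the two layers could be applied per shard without any communication.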

GPU Executor vs. TPU Executor

For GPUs, vLLM uses both eager mode and CUDA graphs for model execution: eager mode for prefills and CUDA graphs for decodes. vLLM currently does not use torch.compile for GPUs but plans to in the future. For TPUs, on the other hand, vLLM will use torch.compile (with the openxla_eval backend) to trace the PyTorch model and lower it into an XLA graph.

While vLLM’s GPU and TPU backends take separate code paths, they will share the PyTorch model code. Most of the custom GPU ops will not be needed on TPUs, since the XLA compiler can auto-generate equivalent code. Therefore, each target op will have two implementations, _forward and _forward_cuda, and vLLM will select one of them at run time depending on the hardware backend. For example, a target op/layer can be defined as follows:

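import torch
import torch.nn as nn


class Op(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch implementation that can be optimized by compilers
        # such as XLA or torch.compile.
        ...

    def _forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Implementation using custom ops written in CUDA.
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Select the implementation based on the hardware backend
        # (checking the input's device is one illustrative choice).
        if x.is_cuda:
            return self._forward_cuda(x)
        else:
            return self._forward(x)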

Important exceptions are the FlashAttention and PagedAttention custom ops, which the XLA compiler cannot generate. For these, we will use custom Pallas kernels.
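For reference, the computation a PagedAttention kernel has to perform can be written as a slow, pure-Python sketch. It is neither the Pallas kernel nor vLLM's implementation. It assumes a single decode query, a single attention head, and a KV cache stored as physical blocks of block_size token slots, which a per-sequence block table maps logical token positions onto; the function and argument names are illustrative. The gather through block_table is the irregular, data-dependent access pattern such a kernel must handle efficiently.

```python
import math
from typing import List

Vector = List[float]


def paged_attention_ref(
    query: Vector,
    key_cache: List[List[Vector]],    # [num_physical_blocks][block_size][head_dim]
    value_cache: List[List[Vector]],  # same layout as key_cache
    block_table: List[int],           # logical block index -> physical block index
    context_len: int,                 # number of valid tokens in the sequence
) -> Vector:
    """Reference semantics of decode-time paged attention for one head."""
    block_size = len(key_cache[0])
    keys, values = [], []
    for pos in range(context_len):
        # Translate the token position into (physical block, offset).
        physical = block_table[pos // block_size]
        offset = pos % block_size
        keys.append(key_cache[physical][offset])
        values.append(value_cache[physical][offset])

    # Scaled dot-product scores followed by a numerically stable softmax.
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(query))]
```

When every key is identical, the softmax weights are uniform and the result is simply the mean of the gathered values; slots past context_len and blocks absent from the table are never read.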

Handling Dynamic Shapes

vLLM’s continuous batching has two phases, prefill and decode, and vLLM dynamically switches between them based on its scheduling decisions. The input tensor shape for prefills is [batch_size, prefill_len, hidden_size], while for decodes it is [batch_size, 1, hidden_size], since LLMs decode tokens one at a time (special cases such as speculative decoding are not considered here). In LLM inference, both batch_size and prefill_len can change at every step.
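The cost of varying shapes can be made concrete with a small simulation. Under the simplifying assumption that the compiler produces and caches one graph per distinct input shape, a hypothetical CompiledGraphCache counts how often a new compilation would be triggered; compile_fn is a placeholder rather than a real XLA call, and hidden_size=4096 in the example is an arbitrary value.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


class CompiledGraphCache:
    """Hypothetical stand-in for per-shape compilation: one graph per input shape."""

    def __init__(self, compile_fn: Callable[[Shape], str]):
        self._compile_fn = compile_fn
        self._graphs: Dict[Shape, str] = {}
        self.num_compilations = 0

    def get(self, shape: Shape) -> str:
        if shape not in self._graphs:
            # A shape never seen before forces a fresh (slow) compilation.
            self._graphs[shape] = self._compile_fn(shape)
            self.num_compilations += 1
        return self._graphs[shape]


def input_shape(phase: str, batch_size: int, prefill_len: int, hidden_size: int) -> Shape:
    # Prefill: [batch_size, prefill_len, hidden_size]; decode: [batch_size, 1, hidden_size].
    if phase == "prefill":
        return (batch_size, prefill_len, hidden_size)
    if phase == "decode":
        return (batch_size, 1, hidden_size)
    raise ValueError(f"unknown phase: {phase}")
```

For example, decode steps with batch sizes 3, 5, 3, 7 followed by prefills of length 17, 17, 33 yield five distinct shapes and therefore five compilations; without bucketing, this count keeps growing as new batch sizes and prompt lengths appear.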

To meet XLA’s static-shape requirement, we will bucketize the possible input shapes. For decodes, we will bucketize the batch_size dimension with buckets batch_size=[8, 16, 24, 32, 40, …, 256]. For prefills, to reduce the number of compiled graphs, we will fix batch_size to 1 and bucketize the prefill_len dimension with buckets prefill_len=[8, 16, 32, 64, 128, …, max_model_len]. Since each prefill input already contains enough tokens to utilize the TPU efficiently, fixing batch_size to 1 should not hurt performance much. The specific bucket sizes will be tuned after benchmarking the compilation overhead and end-to-end performance.
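The bucketing scheme could look roughly like the following sketch. The bucket lists mirror the example sizes above. The helper names and the pad token id 0 are assumptions, and in practice the padded positions would have to be masked out or ignored when sampling (not shown).

```python
import bisect
from typing import List


def decode_batch_buckets(max_batch: int = 256, step: int = 8) -> List[int]:
    # Linear buckets: [8, 16, 24, ..., 256].
    return list(range(step, max_batch + 1, step))


def prefill_len_buckets(max_model_len: int, start: int = 8) -> List[int]:
    # Power-of-two buckets [8, 16, 32, 64, ...], capped at max_model_len.
    buckets = []
    b = start
    while b < max_model_len:
        buckets.append(b)
        b *= 2
    buckets.append(max_model_len)
    return buckets


def pick_bucket(n: int, buckets: List[int]) -> int:
    # Smallest bucket >= n; the input is padded up to this size.
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"{n} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]


def pad_tokens(token_ids: List[int], bucket: int, pad_id: int = 0) -> List[int]:
    # Right-pad a prompt to the chosen bucket length.
    return token_ids + [pad_id] * (bucket - len(token_ids))
```

With max_model_len=2048, this gives at most 32 decode graphs and 9 prefill graphs; a decode batch of 13 sequences is padded to 16, and a 100-token prompt is padded to 128.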
