[issue tracker] make vllm compatible with dynamo #8821

Closed
@youkaichao

Description

The first step to enable torch.compile is to use dynamo to capture the graph. While dynamo can handle many Python features, every time a Python-side value that the captured graph depends on changes, dynamo recompiles the code.

For example:

```python
# test.py
import torch

@torch.compile
def f(x, i):
    return (x + i) * i

x = torch.randn(5, 5).cuda()
f(x, 1)
f(x, 2)
f(x, 3)
```

Running the code with `TORCH_LOGS=recompiles_verbose python test.py`, we get:

```
V0925 13:16:45.714159 140477991954240 torch/_dynamo/guards.py:2609] [0/1] [__recompiles_verbose] Recompiling function f in /data/youkaichao/vllm/testb.py:3
V0925 13:16:45.714159 140477991954240 torch/_dynamo/guards.py:2609] [0/1] [__recompiles_verbose]     triggered by the following guard failure(s):
V0925 13:16:45.714159 140477991954240 torch/_dynamo/guards.py:2609] [0/1] [__recompiles_verbose]     guard 0 failures:
V0925 13:16:45.714159 140477991954240 torch/_dynamo/guards.py:2609] [0/1] [__recompiles_verbose]     - L['i'] == 1
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose] Recompiling function f in /data/youkaichao/vllm/testb.py:3
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose]     triggered by the following guard failure(s):
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose]     guard 0 failures:
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose]     - L['i'] == 2
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose]
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose]     guard 1 failures:
V0925 13:16:45.882663 140477991954240 torch/_dynamo/guards.py:2609] [0/2] [__recompiles_verbose]     - L['i'] == 1
```

Every call with a new value of `i` triggers a recompilation, because dynamo embeds the integer into the graph as a constant, so the graph is only reusable when `i` equals that value (the guard `L['i'] == 1` in the log).

This is because torch.compile is designed to compile tensor programs: the compiled graph generalizes over tensor inputs, but not over Python integers.
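To make the guard mechanism concrete, here is a toy model in plain Python. It is not dynamo's implementation, and `ToyGuardedCache` is a hypothetical name. Each compiled entry stores the guards recorded at compile time; a call reuses an entry only if all of its guards pass. Otherwise the failing guards of every entry are collected and a new entry is compiled, specialized on the current value of `i`. Checking the newest entry first reproduces the ordering seen in the log above, where guard 0 is `L['i'] == 2` and guard 1 is `L['i'] == 1`.

```python
from typing import Callable, List, Tuple

Guard = Tuple[str, Callable[[int], bool]]  # (human-readable description, predicate)


class ToyGuardedCache:
    """Toy model of a guard-checked compile cache (hypothetical; not dynamo's code).

    It specializes on the Python int `i`, so each compiled entry is guarded by
    `i == <value seen at compile time>`.
    """

    def __init__(self, fn: Callable[[float, int], float]):
        self.fn = fn
        self.entries: List[List[Guard]] = []
        self.compile_count = 0
        self.last_failures: List[List[str]] = []

    def __call__(self, x: float, i: int) -> float:
        failures = []
        # Check the most recently compiled entry first.
        for guards in reversed(self.entries):
            failed = [desc for desc, check in guards if not check(i)]
            if not failed:
                return self.fn(x, i)  # every guard passed: reuse this entry
            failures.append(failed)
        # No entry matched: "recompile", baking the current value of i into the guard.
        self.last_failures = failures
        self.compile_count += 1
        self.entries.append([(f"L['i'] == {i}", lambda j, v=i: j == v)])
        return self.fn(x, i)


f = ToyGuardedCache(lambda x, i: (x + i) * i)
for i in (1, 2, 3):
    f(1.0, i)
print(f.compile_count)   # 3
print(f.last_failures)   # [["L['i'] == 2"], ["L['i'] == 1"]]
```

Calling `f(1.0, 1)` again would reuse the first entry without compiling, just as dynamo reuses a cached graph whose guards all pass.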

To solve the problem, we wrap the integer in a tensor. PyTorch then reuses the graph as long as the tensor metadata (device, shape, dtype, etc.) matches:

```python
# test.py
import torch

@torch.compile
def f(x, i):
    i = i.cuda()  # move the wrapped scalar to the same device as x
    return (x + i) * i

x = torch.randn(5, 5).cuda()
f(x, torch.tensor(1))
f(x, torch.tensor(2))
f(x, torch.tensor(3))
```

This version does not trigger recompilation.
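Wrapping helps because a tensor guard inspects the tensor's metadata, not the value stored in it. The sketch below models that difference with a hypothetical `FakeTensor` dataclass instead of real torch tensors, so it runs without PyTorch. It is a simplification: dynamo's real tensor guards check more properties, for example strides and dispatch keys, as the vLLM log further down shows.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: metadata plus a stored value."""
    shape: Tuple[int, ...]
    dtype: str
    device: str
    value: float


def guard_key(arg):
    """Toy guard key: metadata only for tensors, the exact value for Python ints."""
    if isinstance(arg, FakeTensor):
        return ("tensor", arg.shape, arg.dtype, arg.device)  # value is ignored
    return ("int", arg)


def count_compiles(args):
    """Number of graphs a cache keyed by guard_key would compile for these calls."""
    return len({guard_key(arg) for arg in args})


ints = [1, 2, 3]
wrapped = [FakeTensor(shape=(), dtype="int64", device="cpu", value=float(v)) for v in ints]
print(count_compiles(ints))     # 3
print(count_compiles(wrapped))  # 1
```

A change in shape, dtype, or device would still produce a new key, which is why the metadata has to stay stable across calls.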

To integrate with dynamo, we need to design the warmup scheme carefully, so that warmup compiles graphs for all use cases and later runs never trigger compilation. If a new user request triggers compilation, compilation adds several seconds to its TTFT (time to first token).
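One way to picture this requirement is a compile cache that is populated during warmup and then frozen, so that any miss afterwards corresponds to a request that would pay compilation latency. The sketch below assumes that model; `WarmedUpRunner`, `CompileAfterWarmupError`, and the batch-size key are hypothetical illustrations, not vLLM's actual design.

```python
class CompileAfterWarmupError(RuntimeError):
    """Raised when a request would trigger compilation after warmup (hypothetical)."""


class WarmedUpRunner:
    """Toy model: compile every expected variant during warmup, then freeze the cache."""

    def __init__(self, key_fn):
        self.key_fn = key_fn      # maps a batch to the guard key it would compile under
        self.compiled = set()
        self.warmup_done = False

    def run(self, batch):
        key = self.key_fn(batch)
        if key not in self.compiled:
            if self.warmup_done:
                # In a real server this miss would add seconds of compilation to TTFT.
                raise CompileAfterWarmupError(f"compilation triggered after warmup for key {key!r}")
            self.compiled.add(key)  # stands in for an actual compilation
        return key

    def warmup(self, batches):
        for batch in batches:
            self.run(batch)
        self.warmup_done = True


# Specializing on batch size means warmup must cover every batch size that can occur.
runner = WarmedUpRunner(key_fn=len)
runner.warmup([["a"], ["a", "b"]])
runner.run(["c", "d"])  # batch size 2 was warmed up: no compilation
try:
    runner.run(["a", "b", "c", "d"])  # batch size 4 was not
except CompileAfterWarmupError as err:
    print(err)
```

With a key function that ignores the batch size (what a dynamic batch dimension amounts to), the batch of four would reuse the warmed-up entry instead of failing.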

Our first goal is to remove unnecessary Python-side changes between model runs. The following script reveals them:

```python
import os

os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="meta-llama/Meta-Llama-3-8B",
          enforce_eager=True,
          tensor_parallel_size=1,
          disable_custom_all_reduce=True)

# the first batch will compile
outputs = llm.generate(prompts[:1], sampling_params)

# the second batch might also compile, and enable dynamic shape automatically
outputs = llm.generate(prompts[:2], sampling_params)

print("warm up done" + "\n" * 10)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

We use two batches of different sizes to warm up compilation; after that, PyTorch should have captured and compiled graphs for all tensor variations. The final run then reveals the remaining Python-side variations, which we need to remove.
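The script's comment about the second batch enabling dynamic shape automatically refers to what PyTorch calls automatic dynamic shapes: roughly, when a recompilation is caused by a size change, dynamo recompiles with that dimension treated as dynamic instead of specialized. The toy below models only that single rule for one leading dimension; `ToyAutoDynamic` is hypothetical, and real dynamo tracks this per dimension with additional conditions.

```python
class ToyAutoDynamic:
    """Toy model of automatic dynamic shapes for one leading dimension (hypothetical)."""

    def __init__(self):
        self.entries = []       # each entry: a static size (int) or None (dynamic)
        self.first_size = None  # size observed at the first compilation
        self.compile_count = 0

    def __call__(self, size):
        for entry in self.entries:
            if entry is None or entry == size:
                return entry  # an existing graph accepts this size
        self.compile_count += 1
        if self.first_size is None:
            self.first_size = size
            entry = size        # first compilation: specialize on the observed size
        else:
            entry = None        # size changed: recompile with the dimension dynamic
        self.entries.append(entry)
        return entry


compiler = ToyAutoDynamic()
for batch_size in (1, 2, 4, 8):
    compiler(batch_size)
print(compiler.compile_count)  # 2: one static graph for size 1, one dynamic graph
```

In this toy, the two warmup batches in the script play exactly these roles. In practice, Python-side guards such as `num_prefills` can still force recompilation after warmup, as the log below shows.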

After warmup, we still see the following recompilations:

```
warm up done


Processed prompts:   0%|                                             | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose] Recompiling function forward in /data/youkaichao/vllm/vllm/model_executor/models/llama.py:440
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     triggered by the following guard failure(s):
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 0 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - L['attn_metadata'].num_decode_tokens == 0
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 1 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - tensor 'L['input_ids']' size mismatch at index 0. expected 2, actual 4
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 2 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - L['attn_metadata'].num_decode_tokens == 2
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 3 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - tensor 'L['input_ids']' size mismatch at index 0. expected 1, actual 4
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 4 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - tensor 'L['attn_metadata']._cached_decode_metadata.block_tables' size mismatch at index 0. expected 1, actual 4
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 5 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - L['attn_metadata'].num_prefills == 1
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     guard 6 failures:
[rank0]:V0925 13:34:27.040958 139966991710016 torch/_dynamo/guards.py:2609] [0/7] [__recompiles_verbose]     - tensor 'L['input_ids']' dispatch key set mismatch. expected DispatchKeySet(CUDA, BackendSelect), actual DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView)
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose] Recompiling function forward in /data/youkaichao/vllm/vllm/model_executor/models/llama.py:440
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     triggered by the following guard failure(s):
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 0 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - tensor 'L['attn_metadata']._cached_decode_metadata.block_tables' stride mismatch at index 0. expected 1, actual 2
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 1 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - L['attn_metadata'].num_decode_tokens == 0
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 2 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - tensor 'L['input_ids']' size mismatch at index 0. expected 2, actual 4
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 3 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - tensor 'L['attn_metadata']._cached_decode_metadata.block_tables' stride mismatch at index 0. expected 1, actual 2
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 4 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - tensor 'L['input_ids']' size mismatch at index 0. expected 1, actual 4
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 5 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - tensor 'L['attn_metadata']._cached_decode_metadata.block_tables' size mismatch at index 0. expected 1, actual 4
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 6 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - L['attn_metadata'].num_prefills == 1
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     guard 7 failures:
[rank0]:V0925 13:34:39.292751 139966991710016 torch/_dynamo/guards.py:2609] [0/8] [__recompiles_verbose]     - tensor 'L['input_ids']' dispatch key set mismatch. expected DispatchKeySet(CUDA, BackendSelect), actual DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView)
```
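Each recompilation above lists the failing guards of every cached entry, so tallying them by what is being guarded makes the output easier to scan. Below is a small parsing sketch. `summarize_guard_failures` and its regular expressions are hypothetical helpers written against the `[__recompiles_verbose]` line format shown in these logs; that is a debug log format and may change between PyTorch versions.

```python
import re
from collections import Counter

# Assumes the `[__recompiles_verbose]` line format shown in the logs above.
FAILURE_RE = re.compile(r"\[__recompiles_verbose\]\s+- (?P<failure>.+?)\s*$")
VALUE_GUARD_RE = re.compile(r"^(?P<expr>L\[.+?\]\S*) == ")
TENSOR_GUARD_RE = re.compile(r"^tensor '(?P<name>.+)' (?P<kind>.+?) mismatch")


def summarize_guard_failures(log_text):
    """Count failed guards, grouped by the guarded expression and failure kind."""
    counts = Counter()
    for line in log_text.splitlines():
        match = FAILURE_RE.search(line)
        if match is None:
            continue  # headers such as "guard 0 failures:" carry no guard text
        failure = match.group("failure")
        value_guard = VALUE_GUARD_RE.match(failure)
        tensor_guard = TENSOR_GUARD_RE.match(failure)
        if value_guard:
            # Python-side value baked into the graph, e.g. num_prefills == 1
            counts[f"{value_guard.group('expr')} (value)"] += 1
        elif tensor_guard:
            # Tensor metadata guard: size, stride, dispatch key set, ...
            counts[f"{tensor_guard.group('name')} ({tensor_guard.group('kind')})"] += 1
        else:
            counts[failure] += 1
    return counts


sample = "\n".join([
    "[0/8] [__recompiles_verbose]     guard 1 failures:",
    "[0/8] [__recompiles_verbose]     - L['attn_metadata'].num_decode_tokens == 0   ",
    "[0/8] [__recompiles_verbose]     - tensor 'L['input_ids']' size mismatch at index 0. expected 2, actual 4",
])
for guard, n in summarize_guard_failures(sample).most_common():
    print(n, guard)
```

Grouping this way separates the two kinds of variation in the log: Python-side values such as `num_prefills` and `num_decode_tokens`, which are the Python-side changes to remove, and tensor metadata guards on size, stride, and dispatch key set.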
