
[Bug]: Unexpected decode graph compilation after preemption #158

@tae-su-kim

On the vllm-fork habana_next branch (commit 067a243), preemption can cause unexpected decode graph compilation. It can be reproduced with the following command:

```
VLLM_GRAPH_RESERVED_MEM=0.08 VLLM_GRAPH_PROMPT_RATIO=0.1 \
python3 benchmarks/benchmark_throughput.py \
    --input-len 2048 --output-len 1024 \
    --model Meta-Llama-3-8B-Instruct \
    --num-prompts 512 --max-model-len 3072 --device hpu
```

Also, add the following argument at line 103 of benchmark_throughput.py to set block_size=128:

```python
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        block_size=128,  # added for this reproduction
    )
```

Note that with VLLM_GRAPH_RESERVED_MEM=0.08 and VLLM_GRAPH_PROMPT_RATIO=0.1, 100% of the predetermined prefill and decode graphs are captured. With this setup, 3935 blocks can be allocated:

```
INFO 07-24 12:11:53 habana_executor.py:78] # HPU blocks: 3935, # CPU blocks: 256
```
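A rough KV-cache budget helps explain why preemption happens in this run at all. The sketch below is back-of-the-envelope arithmetic, not the scheduler's admission logic; it assumes each request eventually needs input_len + output_len tokens of cache and that up to 256 requests run at once (vLLM's default max_num_seqs, which the command above does not set).

```python
# Back-of-the-envelope KV-cache budget for the reproduction above.
BLOCK_SIZE = 128           # set via block_size=128 in benchmark_throughput.py
NUM_HPU_BLOCKS = 3935      # from the "# HPU blocks" log line
INPUT_LEN, OUTPUT_LEN = 2048, 1024
MAX_NUM_SEQS = 256         # assumption: vLLM's default max_num_seqs


def blocks_per_request(input_len: int, output_len: int, block_size: int) -> int:
    # Ceiling division: a partially filled block still occupies a whole block.
    return -(-(input_len + output_len) // block_size)


per_req = blocks_per_request(INPUT_LEN, OUTPUT_LEN, BLOCK_SIZE)
needed = per_req * MAX_NUM_SEQS
max_concurrent = NUM_HPU_BLOCKS // per_req

print(f"blocks per full-length request: {per_req}")                 # -> 24
print(f"blocks for {MAX_NUM_SEQS} full-length requests: {needed}")  # -> 6144
print(f"full-length requests that fit at once: {max_concurrent}")   # -> 163
```

Since 256 full-length sequences would need roughly 6144 blocks but only 3935 exist, the scheduler has to preempt requests as their decodes grow, which shrinks the live batch below 256.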

The early log output is shown below:

```
WARNING 07-24 12:16:57 habana_model_runner.py:941] Configuration: (prompt, 1, 2048) was not warmed-up!
WARNING 07-24 12:17:25 habana_model_runner.py:941] Configuration: (decode, 256, 4352) was not warmed-up!
WARNING 07-24 12:17:42 habana_model_runner.py:941] Configuration: (decode, 224, 4096) was not warmed-up!
```

While prompt graph misses can be handled easily (see PR 109), decode graph misses are unexpected. Note that the decode lengths in these warnings (4352 and 4096) are longer than max_model_len (3072).
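When triaging runs like this one, a small script can extract the non-warmed-up configurations from a log and flag decode buckets whose length exceeds max_model_len. This is a hypothetical helper, not part of vLLM, and the regular expression assumes the exact warning format shown in the logs above.

```python
import re

# Matches warnings such as:
#   "Configuration: (decode, 256, 4352) was not warmed-up!"
WARN_RE = re.compile(
    r"Configuration: \((prompt|decode), (\d+), (\d+)\) was not warmed-up!"
)


def oversized_decode_configs(log_lines, max_model_len):
    """Return (batch_size, length) for decode misses longer than max_model_len."""
    hits = []
    for line in log_lines:
        m = WARN_RE.search(line)
        if m and m.group(1) == "decode":
            batch, length = int(m.group(2)), int(m.group(3))
            if length > max_model_len:
                hits.append((batch, length))
    return hits


LOG = [
    "WARNING 07-24 12:16:57 habana_model_runner.py:941] Configuration: (prompt, 1, 2048) was not warmed-up!",
    "WARNING 07-24 12:17:25 habana_model_runner.py:941] Configuration: (decode, 256, 4352) was not warmed-up!",
    "WARNING 07-24 12:17:42 habana_model_runner.py:941] Configuration: (decode, 224, 4096) was not warmed-up!",
]

print(oversized_decode_configs(LOG, max_model_len=3072))
# [(256, 4352), (224, 4096)]
```

Prompt misses are skipped on purpose, since they are expected and handled separately; only decode misses beyond max_model_len point at the padding problem described here.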

This issue stems from lines 958-961 of habana_model_runner.py:

```python
batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
seq_group_metadata_list.extend(seq_group_metadata_list[0] for _ in range(batch_size_padding))
```

After preemption, real_batch_size drops below the maximum batch size (e.g., from 256 to 243). In this case, line 961 pads seq_group_metadata_list back to a batch of 256 for bucketing by repeating seq_group_metadata_list[0], and that entry carries the non-empty block_table of the first decode request in the batch. As a result, line 697 of habana_model_runner.py creates a DecodeMetadata whose block_tables are longer than the maximum number of pages, which leads to unpredictable sequence lengths for the decode graphs.
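The padding pattern can be reproduced outside the runner. This sketch uses a hypothetical SeqGroupMeta class and a simplified find_bucket that rounds up to a multiple of an assumed step of 32; neither is the fork's real code. It only shows that every padding entry is the same object as the first request and therefore carries its full block table; it does not model how DecodeMetadata is built from the list.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SeqGroupMeta:
    # Hypothetical stand-in for SequenceGroupMetadata: only the
    # seq_id -> physical block numbers mapping matters here.
    request_id: str
    block_tables: Dict[int, List[int]] = field(default_factory=dict)


def find_bucket(value: int, step: int) -> int:
    # Simplified stand-in: round value up to the next multiple of step.
    return -(-value // step) * step


def pad_by_repeating_first(batch: List[SeqGroupMeta], step: int) -> List[SeqGroupMeta]:
    # Mirrors the quoted padding: append batch[0] until the bucket size is reached.
    padded_size = find_bucket(len(batch), step)
    padded = batch.copy()
    padded.extend(batch[0] for _ in range(padded_size - len(batch)))
    return padded


# 243 live decode requests after preemption, each owning 24 distinct blocks.
batch = [
    SeqGroupMeta(f"req{i}", {i: list(range(i * 24, i * 24 + 24))})
    for i in range(243)
]
padded = pad_by_repeating_first(batch, step=32)
pad_entries = padded[len(batch):]

print(len(padded))                                # 256
print(all(p is batch[0] for p in pad_entries))   # True
print(sum(len(t) for p in pad_entries for t in p.block_tables.values()))  # 312
```

The 13 padding rows together reference 312 block slots (13 × 24) that all belong to the first request, so the padded rows are not "empty" from the runner's point of view.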

One possible fix is to pad the batch dimension with seq_group_metadata entries that have no block_table. As a temporary workaround, the sequence length of the decode buckets can be increased, but this increases the memory consumed by the captured graphs.
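The suggested fix can be sketched with the same kind of hypothetical stand-in: pad with dummy entries that own no blocks instead of repeating the first request. The SeqGroupMeta class, the negative placeholder sequence ids, and pad_with_empty are all assumptions for illustration; whether the runner's downstream code accepts an empty block table for padded rows is not verified here.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SeqGroupMeta:
    # Hypothetical stand-in for SequenceGroupMetadata (block tables only).
    request_id: str
    block_tables: Dict[int, List[int]] = field(default_factory=dict)


def pad_with_empty(batch: List[SeqGroupMeta], padded_size: int) -> List[SeqGroupMeta]:
    """Pad the batch dimension with dummy entries that own no KV-cache blocks."""
    if padded_size < len(batch):
        raise ValueError("padded_size must not be smaller than the batch")
    # Negative seq ids are placeholders chosen so they cannot clash with real ids.
    dummies = [
        SeqGroupMeta(f"pad{i}", {-1 - i: []})
        for i in range(padded_size - len(batch))
    ]
    return batch + dummies


batch = [
    SeqGroupMeta(f"req{i}", {i: list(range(i * 24, i * 24 + 24))})
    for i in range(243)
]
padded = pad_with_empty(batch, 256)

print(len(padded))  # 256
print(max(len(t) for g in padded for t in g.block_tables.values()))  # 24
print(sum(len(t) for g in padded[243:] for t in g.block_tables.values()))  # 0
```

With this padding, the longest block table in the batch is still set by a real request (24 blocks here, i.e. max_model_len / block_size), so the decode bucket length would stay within max_model_len.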
