[core] clean up cudagraph batchsize padding logic (vllm-project#10996)
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored Dec 13, 2024
1 parent 34f1a80 commit be39e3c
Showing 11 changed files with 150 additions and 104 deletions.
5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_jamba.py
@@ -1,7 +1,7 @@
import pytest

from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])

5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_mamba.py
@@ -5,7 +5,7 @@
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])

4 changes: 2 additions & 2 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -4,7 +4,6 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
expanded_batch_size)
cuda_graph_pad_size = graph_batch_size - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list(
4 changes: 2 additions & 2 deletions tests/worker/test_model_runner.py
@@ -3,7 +3,6 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
@@ -177,7 +176,8 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
expected_bs = model_runner.vllm_config.pad_for_cudagraph(
len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.num_prefills == 0
171 changes: 105 additions & 66 deletions vllm/config.py
@@ -2354,6 +2354,12 @@ def model_post_init(self, __context: Any) -> None:
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
max_capture_size: int = PrivateAttr
# optimization:
# Intuitively, bs_to_padded_graph_size should be a Dict[int, int].
# Since we know all keys are in the range [0, max_capture_size],
# we can optimize it to a List[int] for better lookup performance.
bs_to_padded_graph_size: List[int] = PrivateAttr

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
@@ -2365,6 +2371,19 @@ def model_post_init(self, __context: Any) -> None:
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr

def __repr__(self) -> str:
exclude = {
"static_forward_context",
"enabled_custom_ops",
"disabled_custom_ops",
"compilation_time",
"bs_to_padded_graph_size",
"pass_config",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)

__str__ = __repr__

@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
@@ -2450,18 +2469,22 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):

# sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True)
self.max_capture_size = self.capture_sizes[
0] if self.capture_sizes else 0


_BATCH_SIZE_ALIGNMENT = 8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_capture_size + 1)
]
for end, start in zip(self.capture_sizes,
self.capture_sizes[1:] + [0]):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
self.bs_to_padded_graph_size[
self.max_capture_size] = self.max_capture_size
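
A minimal standalone sketch of the table construction above, for readers skimming the diff. The helper name build_padding_table is illustrative (not part of vLLM); it assumes capture_sizes is sorted in descending order, as init_with_cudagraph_sizes guarantees, and fills every batch size in [0, max_capture_size] with the smallest capture size that is at least as large.

def build_padding_table(capture_sizes: list[int]) -> list[int]:
    # capture_sizes is assumed to be sorted in descending order.
    max_capture_size = capture_sizes[0] if capture_sizes else 0
    table = [0] * (max_capture_size + 1)
    for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
        for bs in range(start, end):
            # batch sizes strictly between two capture sizes pad up to `end`;
            # `start` itself is already a capture size and maps to itself.
            table[bs] = start if bs == start else end
    table[max_capture_size] = max_capture_size
    return table

# With capture sizes [8, 4, 2, 1], batch sizes 0..8 pad to:
assert build_padding_table([8, 4, 2, 1]) == [0, 1, 2, 4, 4, 8, 8, 8, 8]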


@dataclass
@@ -2491,40 +2514,12 @@ class VllmConfig:
init=True) # type: ignore
instance_id: str = ""

@staticmethod
def get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

@staticmethod
def get_max_graph_batch_size(max_num_seqs: int) -> int:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
if padded_size in _BATCH_SIZES_TO_CAPTURE:
return padded_size
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
return _BATCH_SIZES_TO_CAPTURE[-1]
def pad_for_cudagraph(self, batch_size: int) -> int:
# If batch_size > self.compilation_config.max_capture_size,
# this raises an IndexError.
# The caller must make sure batch_size is within range,
# i.e., batch_size <= self.compilation_config.max_capture_size.
return self.compilation_config.bs_to_padded_graph_size[batch_size]
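
The new method is an O(1) list lookup with a deliberately narrow contract: it is only valid for batch sizes up to compilation_config.max_capture_size, and anything larger raises an IndexError rather than silently falling back. A hedged usage sketch, mirroring the updated tests above (the model id is a placeholder and the exact padded value depends on the engine's capture sizes):

from vllm.engine.arg_utils import EngineArgs

# Requires a vLLM installation and access to the model's config files.
vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
# With the default capture sizes [1, 2, 4, 8, 16, ...], e.g. 3 pads to 4.
print(vllm_config.pad_for_cudagraph(3))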

@staticmethod
def _get_quantization_config(
@@ -2618,27 +2613,7 @@ def __post_init__(self):
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

if not envs.VLLM_USE_V1:
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:
max_batchsize_to_capture = \
self.get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
size for size in _BATCH_SIZES_TO_CAPTURE
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
self._set_cudagraph_sizes()

if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
@@ -2659,6 +2634,70 @@ def __post_init__(self):
if not self.instance_id:
self.instance_id = random_uuid()[:5]

def _set_cudagraph_sizes(self):
"""
Cudagraph batch-size padding logic:
`[1, 2, 4] + [8 * i for i in range(1, 1025)]` is the list of all possible
batch sizes that cudagraph can capture.
Depending on the engine's `max_num_seqs` configuration, the candidate
batch sizes to capture shrink to the subset that just covers the range
`[1, max_num_seqs]`. In the common case, `max_num_seqs` is 256, and the
cudagraph batch sizes are `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
However, if users specify the cudagraph capture sizes through the
compilation config, the specified sizes are used instead.
In the end, `vllm_config.compilation_config.capture_sizes` holds the
final sizes to capture cudagraph (in descending order).
At runtime, if the batch size is larger than the largest size in
`vllm_config.compilation_config.capture_sizes`, no cudagraph is used.
If the batch size is no larger than that, the padded graph size for a
given batch size can be found quickly by looking it up in
`vllm_config.compilation_config.bs_to_padded_graph_size`.
"""

# calculate the default `batch_size_capture_list`
if not envs.VLLM_USE_V1:
batch_size_capture_list = []
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:

possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
# find the smallest size that is at least max_num_seqs,
# which then becomes the max_batchsize_to_capture
larger_sizes = [
x for x in possible_sizes
if x >= self.scheduler_config.max_num_seqs
]
if larger_sizes:
max_batchsize_to_capture = larger_sizes[0]
else:
max_batchsize_to_capture = possible_sizes[-1]

# filter out the sizes that are
# larger than max_batchsize_to_capture
batch_size_capture_list = [
size for size in possible_sizes
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
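
A self-contained sketch of the default (non-V1, CUDA-graph-enabled) branch above, assuming no user-specified capture sizes override it; the function name is illustrative. It reproduces the cutoff at the smallest candidate size that is >= max_num_seqs:

def default_capture_sizes(max_num_seqs: int) -> list[int]:
    # Candidate sizes: 1, 2, 4, then multiples of 8 up to 8192.
    possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
    # Cap at the smallest candidate >= max_num_seqs (or the largest
    # candidate if max_num_seqs exceeds them all).
    larger_sizes = [x for x in possible_sizes if x >= max_num_seqs]
    cap = larger_sizes[0] if larger_sizes else possible_sizes[-1]
    return [size for size in possible_sizes if size <= cap]

# max_num_seqs=40 keeps [1, 2, 4, 8, 16, 24, 32, 40];
# max_num_seqs=256 keeps [1, 2, 4, 8, 16, ..., 256].
assert default_capture_sizes(40) == [1, 2, 4, 8, 16, 24, 32, 40]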

def __str__(self):
return (
f"model={self.model_config.model!r},"
20 changes: 14 additions & 6 deletions vllm/model_executor/models/jamba.py
@@ -7,7 +7,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -420,6 +420,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2
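
The same sizing decision for the mamba cache, written as a pure function for clarity; this is a sketch with illustrative names, not vLLM API (in the model the inputs come from scheduler_config and compilation_config, and the eager fallback of 8192 + 2 matches the diff above). The Mamba model below follows the identical pattern.

def mamba_cache_max_batch_size(max_num_seqs: int,
                               capture_sizes: list[int],
                               enforce_eager: bool) -> int:
    if enforce_eager or not capture_sizes:
        # no CUDA graphs will be captured; use the large eager fallback
        return 8192 + 2
    max_capture_size = max(capture_sizes)
    if max_num_seqs > max_capture_size:
        # clamp to the largest captured size
        return max_capture_size
    # otherwise pad up to the smallest capture size >= max_num_seqs
    return min(s for s in capture_sizes if s >= max_num_seqs)

# With capture sizes [1, 2, 4, 8, ..., 256]:
#   max_num_seqs=40  -> 40
#   max_num_seqs=300 -> 256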

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -433,15 +444,12 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape())
(
mamba_cache_tensors,
state_indices_tensor,
21 changes: 14 additions & 7 deletions vllm/model_executor/models/mamba.py
@@ -6,7 +6,7 @@
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -195,6 +195,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids)
@@ -208,15 +219,11 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape())

(
mamba_cache_tensors,