Commit 70755e8

[V1][Core] Autotune encoder cache budget (vllm-project#11895)
Signed-off-by: Roger Wang <ywang@roblox.com>
1 parent edce722 commit 70755e8

File tree: 6 files changed, +167 -50 lines

vllm/config.py (+10, -5)

```diff
@@ -1387,13 +1387,15 @@ class SchedulerConfig:
 
     is_multimodal_model: bool = False
 
-    # FIXME(woosuk & ywang96): Below are placeholder values. We need to
-    # calculate the actual values from the configurations.
-    # Multimodal encoder run compute budget, only used in V1
-    max_num_encoder_input_tokens = 16384
+    # NOTE: The following multimodal encoder budget will be initialized to
+    # max_num_batched_tokens and overridden in case max multimodal embedding
+    # size is larger.
+    # TODO (ywang96): Make these configurable.
+    # Multimodal encoder compute budget, only used in V1
+    max_num_encoder_input_tokens: int = field(default=None)  # type: ignore
 
     # Multimodal encoder cache size, only used in V1
-    encoder_cache_size = 16384
+    encoder_cache_size: int = field(default=None)  # type: ignore
 
     # Whether to perform preemption by swapping or
     # recomputation. If not specified, we determine the mode as follows:

@@ -1467,6 +1469,9 @@ def __post_init__(self) -> None:
                 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
             )
 
+        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
+        self.encoder_cache_size = self.max_num_batched_tokens
+
         if self.enable_chunked_prefill:
            logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
```

vllm/multimodal/registry.py (+24, -5)

```diff
@@ -252,11 +252,8 @@ def get_max_tokens_per_item_by_modality(
         model_config: "ModelConfig",
     ) -> Mapping[str, int]:
         """
-        Get the maximum number of tokens per data item from each modality
-        for profiling the memory usage of a model.
-
-        Note:
-            This is currently directly used only in V1.
+        Get the maximum number of tokens per data item from each modality based
+        on underlying model configuration.
         """
         if self.has_processor(model_config):
             tokenizer = cached_get_tokenizer(

@@ -272,6 +269,28 @@ def get_max_tokens_per_item_by_modality(
             for key, plugin in self._plugins.items()
         }
 
+    def get_max_tokens_per_item_by_nonzero_modality(
+        self,
+        model_config: "ModelConfig",
+    ) -> Mapping[str, int]:
+        """
+        Get the maximum number of tokens per data item from each modality based
+        on underlying model configuration, excluding modalities that user
+        explicitly disabled via `limit_mm_per_prompt`.
+
+        Note:
+            This is currently directly used only in V1 for profiling the memory
+            usage of a model.
+        """
+        limits_per_plugin = self._limits_by_model[model_config]
+
+        return {
+            key: max_tokens_per_mm_item
+            for key, max_tokens_per_mm_item in
+            self.get_max_tokens_per_item_by_modality(model_config).items()
+            if limits_per_plugin[key] > 0
+        }
+
     def get_max_tokens_by_modality(
         self,
         model_config: "ModelConfig",
```
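
The new method simply drops any modality whose per-prompt limit is zero, so a user-disabled modality no longer inflates the encoder budget. A standalone sketch of that filtering with made-up per-item token counts (the numbers are not from any real model):

```python
# Standalone sketch of the nonzero-modality filter; example values are
# hypothetical and do not come from a real model config.
max_tokens_per_item = {"image": 3577, "video": 9800, "audio": 188}
limit_mm_per_prompt = {"image": 5, "video": 0, "audio": 1}  # user setting

nonzero = {
    modality: max_tokens
    for modality, max_tokens in max_tokens_per_item.items()
    if limit_mm_per_prompt.get(modality, 1) > 0
}

# "video" was disabled via limit_mm_per_prompt, so it is excluded from
# budget autotuning and memory profiling.
assert nonzero == {"image": 3577, "audio": 188}
```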

vllm/v1/core/encoder_cache_manager.py (+77, -1)

```diff
@@ -1,7 +1,14 @@
-from typing import Dict, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Set, Tuple
 
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.v1.request import Request
 
+if TYPE_CHECKING:
+    from vllm.config import ModelConfig, SchedulerConfig
+
+logger = init_logger(__name__)
+
 
 class EncoderCacheManager:
 

@@ -46,3 +53,72 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []
         return freed
+
+
+def compute_encoder_budget(
+    model_config: "ModelConfig",
+    scheduler_config: "SchedulerConfig",
+) -> Tuple[int, int]:
+    """Compute the encoder cache budget based on the model and scheduler
+    configurations.
+
+    Args:
+        model_config: Model configuration.
+        scheduler_config: Scheduler configuration.
+
+    Returns:
+        - Compute budget for encoder execution, in unit of number of tokens
+          in the input sequence.
+        - Space budget for encoder cache size, in unit of number of tokens
+          in the input sequence.
+    """
+
+    if not model_config.is_multimodal_model:
+        return 0, 0
+
+    # TODO: handle encoder-decoder models once we support them.
+    (
+        encoder_compute_budget,
+        encoder_cache_size,
+    ) = _compute_encoder_budget_multimodal(model_config, scheduler_config)
+
+    return encoder_compute_budget, encoder_cache_size
+
+
+def _compute_encoder_budget_multimodal(
+    model_config: "ModelConfig",
+    scheduler_config: "SchedulerConfig",
+) -> Tuple[int, int]:
+    """Compute the encoder cache budget based on the model and scheduler
+    configurations for a multimodal model.
+
+    Args:
+        model_config: Model configuration.
+        scheduler_config: Scheduler configuration.
+
+    Returns:
+        - Compute budget for encoder execution, in unit of number of tokens
+          in the input sequence.
+        - Space budget for encoder cache size, in unit of number of tokens
+          in the input sequence.
+    """
+
+    max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
+        model_config)
+
+    if not max_tokens_by_modality_dict:
+        logger.warning(
+            "All non-text modalities supported by the model have been "
+            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
+            "not be initialized.")
+        return 0, 0
+
+    _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
+                                    key=lambda item: item[1])
+
+    encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
+                                 max_tokens_per_mm_item)
+    encoder_cache_size = max(scheduler_config.encoder_cache_size,
+                             max_tokens_per_mm_item)
+
+    return encoder_compute_budget, encoder_cache_size
```
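
The autotuning rule is a pair of `max()` calls: the budgets start at the scheduler defaults (which `__post_init__` set to `max_num_batched_tokens`) and are raised to the cost of the single largest multimodal item, so at least one maximum-size item always fits. A self-contained sketch with hypothetical numbers, not tied to any particular model:

```python
# Standalone sketch of the rule in _compute_encoder_budget_multimodal.
# The function name and the token counts here are illustrative only.
from typing import Dict, Tuple


def autotune_budget(max_num_batched_tokens: int,
                    max_tokens_per_item: Dict[str, int]) -> Tuple[int, int]:
    # Both budgets default to max_num_batched_tokens (see SchedulerConfig).
    compute_budget = max_num_batched_tokens
    cache_size = max_num_batched_tokens

    if not max_tokens_per_item:
        # All non-text modalities disabled: no encoder cache at all.
        return 0, 0

    largest_item = max(max_tokens_per_item.values())
    # Raise both budgets so that one maximum-size item can be scheduled.
    return max(compute_budget, largest_item), max(cache_size, largest_item)


# A 2048-token batch budget is bumped to 9800 for a video-heavy model.
assert autotune_budget(2048, {"image": 3577, "video": 9800}) == (9800, 9800)
# A generous batch budget already covers the largest item, so it is kept.
assert autotune_budget(16384, {"image": 3577}) == (16384, 16384)
```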

vllm/v1/core/scheduler.py (+18, -8)

```diff
@@ -3,10 +3,11 @@
 from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
                     Tuple, Union)
 
-from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
+from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
+from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
+                                                compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.metrics.stats import SchedulerStats

@@ -25,6 +26,7 @@ class Scheduler:
     def __init__(
         self,
         scheduler_config: SchedulerConfig,
+        model_config: ModelConfig,
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
     ) -> None:

@@ -69,16 +71,24 @@ def __init__(
         self.running_reqs_data: Dict[str, RunningRequestData] = {}
 
         # Encoder-related.
+        # Calculate encoder cache size if applicable
+        # NOTE: For now we use the same budget for both compute and space.
+        # This can be changed when we make encoder cache for embedding caching
+        # across requests.
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+
         # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
         # projector if needed). Currently, we assume that the encoder also
         # has the Transformer architecture (e.g., ViT).
-        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  #noqa: E501
-        # NOTE(woosuk): For the models without encoder (e.g., text-only models),
-        # the encoder cache will not be initialized and used, regardless of
-        # the cache size. This is because the memory space for the encoder cache
-        # is preallocated in the profiling run.
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        # NOTE: For the models without encoder (e.g., text-only models),
+        # the encoder cache will not be initialized because cache size is 0
+        # for these models.
         self.encoder_cache_manager = EncoderCacheManager(
-            cache_size=self.scheduler_config.encoder_cache_size)
+            cache_size=encoder_cache_size)
 
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
```

vllm/v1/engine/core.py (+6, -3)

```diff
@@ -54,9 +54,12 @@ def __init__(
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
 
         # Setup scheduler.
-        self.scheduler = Scheduler(vllm_config.scheduler_config,
-                                   vllm_config.cache_config,
-                                   vllm_config.lora_config)
+        self.scheduler = Scheduler(
+            scheduler_config=vllm_config.scheduler_config,
+            model_config=vllm_config.model_config,
+            cache_config=vllm_config.cache_config,
+            lora_config=vllm_config.lora_config,
+        )
 
         self.mm_input_mapper_server = MMInputMapperServer(
             vllm_config.model_config)
```

vllm/v1/worker/gpu_model_runner.py (+32, -28)

```diff
@@ -20,6 +20,7 @@
                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata

@@ -88,8 +89,12 @@ def __init__(
         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
-        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
-        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model

@@ -721,44 +726,30 @@ def profile_run(self) -> None:
         ]
 
         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
-
-            # Create dummy batch of multimodal inputs.
-            dummy_request_data = self.input_registry.dummy_data_for_profiling(
-                model_config=self.model_config,
-                seq_len=self.max_num_tokens,
-                mm_registry=self.mm_registry,
-            )
-            dummy_mm_data = dummy_request_data.multi_modal_data
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):
 
             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
+            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
                 self.model_config)
-
             dummy_data_modality, max_tokens_per_mm_item = max(
                 max_tokens_by_modality_dict.items(), key=lambda item: item[1])
 
             # Check how many items of this modality can be supported by
-            # the encoder cache budget.
-            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
-                                       self.encoder_cache_size)
-            max_num_mm_items_encoder_budget = encoder_cache_budget // \
-                max_tokens_per_mm_item
-
-            # TODO: Allow users to set encoder_cache_budget in case this
-            # happens.
-            assert max_num_mm_items_encoder_budget > 0, (
-                f"Encoder cache budget={encoder_cache_budget} is too small to "
-                f"support the maximum possible size of multimodal embeddings"
-                f"={max_tokens_per_mm_item}.")
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)
 
             # Check how many items of this modality can be supported by
             # the decoder budget.
-            max_mm_items_per_req = max(
-                self.mm_registry.get_mm_limits_per_prompt(
-                    self.model_config).values())
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]
 
             # NOTE: We do not consider max_num_batched_tokens on purpose
             # because the multimodal embeddings can be generated in advance

@@ -769,6 +760,19 @@ def profile_run(self) -> None:
             max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                    max_num_mm_items_decoder_budget)
 
+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+
             # Dummy data definition in V0 may contain multiple multimodal items
             # (e.g, multiple images) for a single request, therefore here we
             # always replicate first item by max_num_mm_items times since in V1
```
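
The profiling change swaps the old floor division plus assert (which could fail when a single item exceeded the budget) for a ceiling division, so at least one maximum-size item is always profiled, and it picks the per-prompt limit of the chosen modality rather than the maximum across modalities. A self-contained sketch of the sizing arithmetic; `cdiv` here is a local stand-in and the decoder-side formula and numbers are illustrative assumptions, not the vLLM code itself:

```python
# Standalone sketch of how profile_run sizes the dummy multimodal batch.
def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


encoder_budget = 9800           # min(compute budget, cache size)
max_tokens_per_mm_item = 9800   # cost of the largest item of the chosen modality
max_mm_items_per_req = 3        # limit_mm_per_prompt for that modality
max_num_reqs = 2                # hypothetical max concurrent requests

# Encoder side: cdiv (instead of floor division plus an assert) guarantees
# at least one max-size item fits, even if the budget equals the item cost.
max_num_mm_items_encoder_budget = cdiv(encoder_budget, max_tokens_per_mm_item)

# Decoder side (assumed formula for illustration): the most items that all
# concurrent requests could ever need in one step.
max_num_mm_items_decoder_budget = max_num_reqs * max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
                       max_num_mm_items_decoder_budget)
assert max_num_mm_items == 1  # profile with one maximum-size item
```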
