                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -88,8 +89,12 @@ def __init__(
         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
-        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
-        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
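
For context, this hunk replaces two direct `scheduler_config` reads with a single helper, `compute_encoder_budget`, which returns the encoder compute budget and the encoder cache size as a pair. Below is a minimal sketch of that contract, not vLLM's implementation: it assumes (consistent with the new `> 0` guards added to `profile_run` further down) that models without a multimodal encoder get a zero budget.

```python
from typing import Tuple


def compute_encoder_budget_sketch(model_config, scheduler_config) -> Tuple[int, int]:
    """Illustrative stand-in for compute_encoder_budget; not vLLM's code."""
    # Assumed property: text-only models get no encoder budget, which is
    # what lets the new `> 0` guards in profile_run skip encoder profiling.
    if not getattr(model_config, "is_multimodal_model", False):
        return 0, 0
    # Assumed fields: the same scheduler knobs the old code read directly.
    return (scheduler_config.max_num_encoder_input_tokens,
            scheduler_config.encoder_cache_size)
```
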
@@ -721,44 +726,30 @@ def profile_run(self) -> None:
         ]
 
         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
-
-            # Create dummy batch of multimodal inputs.
-            dummy_request_data = self.input_registry.dummy_data_for_profiling(
-                model_config=self.model_config,
-                seq_len=self.max_num_tokens,
-                mm_registry=self.mm_registry,
-            )
-            dummy_mm_data = dummy_request_data.multi_modal_data
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):
 
             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
+            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
                 self.model_config)
-
             dummy_data_modality, max_tokens_per_mm_item = max(
                 max_tokens_by_modality_dict.items(), key=lambda item: item[1])
 
             # Check how many items of this modality can be supported by
-            # the encoder cache budget.
-            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
-                                       self.encoder_cache_size)
-            max_num_mm_items_encoder_budget = encoder_cache_budget // \
-                max_tokens_per_mm_item
-
-            # TODO: Allow users to set encoder_cache_budget in case this
-            # happens.
-            assert max_num_mm_items_encoder_budget > 0, (
-                f"Encoder cache budget={encoder_cache_budget} is too small to "
-                f"support the maximum possible size of multimodal embeddings"
-                f"={max_tokens_per_mm_item}.")
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)
 
             # Check how many items of this modality can be supported by
             # the decoder budget.
-            max_mm_items_per_req = max(
-                self.mm_registry.get_mm_limits_per_prompt(
-                    self.model_config).values())
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]
 
             # NOTE: We do not consider max_num_batched_tokens on purpose
             # because the multimodal embeddings can be generated in advance
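
The floor division and its follow-up assert are replaced with `cdiv`, vLLM's ceiling-division helper from `vllm.utils`. Ceiling division guarantees at least one item is profiled even when a single maximum-size item would exceed the encoder budget, which is exactly the case the removed assert used to reject. A self-contained illustration with made-up numbers:

```python
def cdiv(a: int, b: int) -> int:
    # Local equivalent of vllm.utils.cdiv: ceiling division on ints.
    return -(a // -b)


encoder_budget = min(8192, 4096)   # min(compute budget, cache size) -> 4096
max_tokens_per_mm_item = 5000      # hypothetical largest single-item size

# Floor division yields 0 items (the old code needed an assert to catch
# this); cdiv still profiles one item.
assert encoder_budget // max_tokens_per_mm_item == 0
assert cdiv(encoder_budget, max_tokens_per_mm_item) == 1
```
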
@@ -769,6 +760,19 @@ def profile_run(self) -> None:
             max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                    max_num_mm_items_decoder_budget)
 
+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+
             # Dummy data definition in V0 may contain multiple multimodal items
             # (e.g, multiple images) for a single request, therefore here we
             # always replicate first item by max_num_mm_items times since in V1
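
Taken together, the profiling arithmetic in `profile_run` now works as in the standalone sketch below. All numbers are hypothetical, and the decoder-budget formula (max requests times the per-request item limit) is assumed from surrounding code that is not part of this diff:

```python
# Hypothetical per-item token maxima by modality; variable names mirror
# the ones in profile_run.
max_tokens_by_modality = {"image": 2560, "audio": 750}

# Profile only the modality whose single item needs the most tokens.
dummy_data_modality, max_tokens_per_mm_item = max(
    max_tokens_by_modality.items(), key=lambda item: item[1])
assert (dummy_data_modality, max_tokens_per_mm_item) == ("image", 2560)

encoder_budget = min(16384, 8192)  # min(compute budget, cache size) -> 8192
# cdiv: ceil(8192 / 2560) -> 4 items fit in the encoder budget.
max_num_mm_items_encoder_budget = -(encoder_budget // -max_tokens_per_mm_item)

max_mm_items_per_req = 2   # hypothetical per-prompt limit for "image"
max_num_reqs = 128         # hypothetical scheduler request limit
max_num_mm_items_decoder_budget = max_num_reqs * max_mm_items_per_req  # 256

max_num_mm_items = min(max_num_mm_items_encoder_budget,
                       max_num_mm_items_decoder_budget)
assert max_num_mm_items == 4
```
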