                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -88,8 +89,12 @@ def __init__(
         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
-        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
-        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
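
The two direct reads of `scheduler_config` are folded into a single call to the new `compute_encoder_budget` helper, which lets text-only models come back with zero budgets so the profiling branch in `profile_run` (below) is skipped outright. A minimal sketch of one plausible shape for the helper, purely illustrative (the real implementation lives in `vllm/v1/core/encoder_cache_manager.py`):

```python
# Illustrative sketch only, NOT the actual vLLM implementation: one plausible
# shape for compute_encoder_budget, assuming ModelConfig exposes
# is_multimodal_model and the scheduler config carries the two limits.
def compute_encoder_budget(model_config, scheduler_config) -> tuple[int, int]:
    # Text-only models have no encoder inputs to budget for, so zero
    # budgets disable the multimodal profiling path entirely.
    if not model_config.is_multimodal_model:
        return 0, 0
    return (scheduler_config.max_num_encoder_input_tokens,
            scheduler_config.encoder_cache_size)
```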
@@ -721,44 +726,30 @@ def profile_run(self) -> None:
         ]
 
         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
-
-            # Create dummy batch of multimodal inputs.
-            dummy_request_data = self.input_registry.dummy_data_for_profiling(
-                model_config=self.model_config,
-                seq_len=self.max_num_tokens,
-                mm_registry=self.mm_registry,
-            )
-            dummy_mm_data = dummy_request_data.multi_modal_data
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):
 
             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
+            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
                 self.model_config)
-
             dummy_data_modality, max_tokens_per_mm_item = max(
                 max_tokens_by_modality_dict.items(), key=lambda item: item[1])
 
             # Check how many items of this modality can be supported by
-            # the encoder cache budget.
-            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
-                                       self.encoder_cache_size)
-            max_num_mm_items_encoder_budget = encoder_cache_budget // \
-                max_tokens_per_mm_item
-
-            # TODO: Allow users to set encoder_cache_budget in case this
-            # happens.
-            assert max_num_mm_items_encoder_budget > 0, (
-                f"Encoder cache budget={encoder_cache_budget} is too small to "
-                f"support the maximum possible size of multimodal embeddings"
-                f"={max_tokens_per_mm_item}.")
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)
 
             # Check how many items of this modality can be supported by
             # the decoder budget.
-            max_mm_items_per_req = max(
-                self.mm_registry.get_mm_limits_per_prompt(
-                    self.model_config).values())
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]
 
             # NOTE: We do not consider max_num_batched_tokens on purpose
             # because the multimodal embeddings can be generated in advance
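
Two behavioral changes worth calling out in this hunk: the floor division plus assert is replaced by `cdiv` (vLLM's ceiling-division helper from `vllm.utils`), so profiling always covers at least one item whenever the budget is positive, and the per-request limit is now read for the chosen modality instead of the maximum across all modalities. A worked example of the arithmetic, with made-up numbers:

```python
# Worked example with made-up numbers; this local cdiv mirrors
# vllm.utils.cdiv, which performs ceiling division.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)

max_num_encoder_input_tokens = 8192  # hypothetical compute budget
encoder_cache_size = 6000            # hypothetical cache budget
max_tokens_per_mm_item = 2560        # hypothetical largest single item

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)  # 6000
# Floor division gives 2 here, and would give 0 (tripping the old assert)
# whenever the budget is smaller than one item; cdiv rounds up instead,
# so profiling still exercises at least one full-size item.
assert cdiv(encoder_budget, max_tokens_per_mm_item) == 3
```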
@@ -769,6 +760,19 @@ def profile_run(self) -> None:
             max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                    max_num_mm_items_decoder_budget)
 
+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+
             # Dummy data definition in V0 may contain multiple multimodal items
             # (e.g., multiple images) for a single request, therefore here we
             # always replicate first item by max_num_mm_items times since in V1
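
The dummy request is now built only after the budget math, so the new log line can report the final numbers. A heavily hedged sketch of the replication step the trailing comment describes, assuming `dummy_mm_data` maps each modality to a list of items (the names introduced here are hypothetical):

```python
# Hypothetical sketch of the replication described above: V1 caches encoder
# outputs per item, so profiling replicates the first dummy item of the
# chosen modality rather than using V0's multi-item request as-is.
dummy_mm_item = dummy_mm_data[dummy_data_modality][0]  # assumed layout
batched_dummy_mm_inputs = [dummy_mm_item] * max_num_mm_items
```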