53 | 53 | AsyncModelRunnerOutput, KVConnectorOutput) |
54 | 54 | from vllm.v1.sample.metadata import SamplingMetadata |
55 | 55 | from vllm.v1.worker.utils import bind_kv_cache |
56 | | -from vllm.v1.utils import CpuGpuBuffer |
57 | 56 | from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch, CachedRequestState |
58 | 57 | from vllm.distributed.parallel_state import get_pp_group |
59 | 58 | from vllm.model_executor.models.interfaces import (SupportsMultiModal, supports_eagle3, supports_transcription) |
@@ -543,10 +542,8 @@ def forward(self, *args, **kwargs): |
543 | 542 | self._reset_rotary_cos_sin() |
544 | 543 | return hidden_states |
545 | 544 |
546 | | - def get_input_embeddings(self, input_ids, multimodal_embeddings=None, is_multimodal=False): |
547 | | - return self.model.get_input_embeddings(input_ids=input_ids, |
548 | | - multimodal_embeddings=multimodal_embeddings, |
549 | | - is_multimodal=is_multimodal) |
| 545 | + def get_input_embeddings(self, input_ids, multimodal_embeddings=None): |
| 546 | + return self.model.get_input_embeddings(input_ids=input_ids, multimodal_embeddings=multimodal_embeddings) |
550 | 547 |
551 | 548 | def get_multimodal_embeddings(self, **batched_mm_inputs): |
552 | 549 | return self.model.get_multimodal_embeddings(**batched_mm_inputs) |
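Since the wrapper no longer forwards an `is_multimodal` mask, the underlying model is left to work out where the multimodal tokens sit (typically from placeholder token ids) when it merges `multimodal_embeddings` into the text embeddings. The sketch below only illustrates that merge pattern; the function name, the single `placeholder_token_id`, and the mask-based scatter are assumptions for the example, not the actual vLLM `get_input_embeddings` implementation.

```python
import torch


def merge_mm_embeddings_sketch(
    input_ids: torch.Tensor,        # (num_tokens,) token ids; image positions hold a placeholder id
    text_embeds: torch.Tensor,      # (num_tokens, hidden) output of the text embedding table
    mm_embeds: list[torch.Tensor],  # per-item encoder outputs, each (item_tokens, hidden)
    placeholder_token_id: int,
) -> torch.Tensor:
    """Scatter multimodal embeddings over the placeholder positions."""
    if not mm_embeds:
        return text_embeds
    is_mm = input_ids == placeholder_token_id
    flat = torch.cat(mm_embeds, dim=0)
    assert int(is_mm.sum()) == flat.shape[0], "placeholder count must match mm token count"
    merged = text_embeds.clone()
    merged[is_mm] = flat
    return merged


# Tiny smoke test: a 6-token prompt with a 4-token image placeholder run (id 9).
ids = torch.tensor([1, 9, 9, 9, 9, 2])
out = merge_mm_embeddings_sketch(ids, torch.zeros(6, 8), [torch.ones(4, 8)], placeholder_token_id=9)
assert out[1:5].eq(1).all() and out[0].eq(0).all()
```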
@@ -741,8 +738,6 @@ def __init__( |
741 | 738 | self.mm_registry = MULTIMODAL_REGISTRY |
742 | 739 | self.uses_mrope = model_config.uses_mrope |
743 | 740 | self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(model_config) |
744 | | - if self.supports_mm_inputs: |
745 | | - self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) |
746 | 741 | self.is_multimodal_raw_input_supported = (model_config.is_multimodal_raw_input_only_model) |
747 | 742 |
748 | 743 | # Lazy initialization |
@@ -853,9 +848,6 @@ def __init__( |
853 | 848 | assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!' |
854 | 849 | assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!' |
855 | 850 |
856 | | - def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True) -> CpuGpuBuffer: |
857 | | - return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory, with_numpy=numpy) |
858 | | - |
859 | 851 | def unified_bucketing_fn(self, is_causal, query_len, shared_blocks, unique_blocks, logits): |
860 | 852 | if not get_config().use_bucketing: |
861 | 853 | return query_len, shared_blocks, unique_blocks, logits |
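The removed `_make_buffer` was a thin wrapper around `vllm.v1.utils.CpuGpuBuffer`, which, as the deleted call sites in this diff show via `.cpu` and `copy_to_gpu(n)`, pairs a host staging tensor with a device tensor so per-step metadata can be written on the host and pushed to the device in one bounded copy. The sketch below approximates that pattern from those call sites only; it is not the real `CpuGpuBuffer` class, and the class name and constructor are assumptions.

```python
import torch


class PairedBufferSketch:
    """Host/device tensor pair: write into .cpu each step, then push the first
    n elements to the device with copy_to_gpu(n). Pinning is only attempted
    when an accelerator is available, so the sketch also runs on CPU-only hosts."""

    def __init__(self, *size: int, dtype: torch.dtype, device: str = "cpu"):
        pin = device != "cpu" and torch.cuda.is_available()
        self.cpu = torch.zeros(*size, dtype=dtype, pin_memory=pin)
        self.gpu = torch.zeros(*size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Only the prefix written this step is transferred.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]


buf = PairedBufferSketch(16, dtype=torch.bool)   # e.g. an is_mm_embed-style per-token mask
buf.cpu[:4] = True
mask_on_device = buf.copy_to_gpu(4)
```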
@@ -1319,15 +1311,8 @@ def _gather_mm_embeddings( |
1319 | 1311 | scheduler_output: "SchedulerOutput", |
1320 | 1312 | req_ids: list[str], |
1321 | 1313 | shift_computed_tokens: int = 0, |
1322 | | - total_num_scheduled_tokens: Optional[int] = None, |
1323 | | - ) -> tuple[list[torch.Tensor], torch.Tensor]: |
1324 | | - total_num_scheduled_tokens = total_num_scheduled_tokens or scheduler_output.total_num_scheduled_tokens |
1325 | | - |
1326 | | - mm_embeds = list[torch.Tensor]() |
1327 | | - is_mm_embed = self.is_mm_embed.cpu |
1328 | | - is_mm_embed[:total_num_scheduled_tokens] = False |
1329 | | - |
1330 | | - req_start_idx = 0 |
| 1314 | + ) -> list[torch.Tensor]: |
| 1315 | + mm_embeds: list[torch.Tensor] = [] |
1331 | 1316 | for req_id in req_ids: |
1332 | 1317 | num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] |
1333 | 1318 | req_state = self.requests[req_id] |
@@ -1366,17 +1351,8 @@ def _gather_mm_embeddings( |
1366 | 1351 | encoder_output[start_idx:end_idx], |
1367 | 1352 | is_embed=is_embed, |
1368 | 1353 | ) |
1369 | | - req_start_pos = req_start_idx + start_pos - num_computed_tokens |
1370 | | - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ |
1371 | | - = True |
1372 | | - |
1373 | | - # Only whole mm items are processed |
1374 | 1354 | mm_embeds.append(mm_embeds_item) |
1375 | | - req_start_idx += num_scheduled_tokens |
1376 | | - |
1377 | | - is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) |
1378 | | - |
1379 | | - return mm_embeds, is_mm_embed |
| 1355 | + return mm_embeds |
1380 | 1356 |
1381 | 1357 | def get_model(self) -> torch.nn.Module: |
1382 | 1358 | assert self.model is not None |
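The simplified `_gather_mm_embeddings` now only collects, in request order, the encoder-output slices that overlap the tokens scheduled this step; it no longer builds the per-token `is_mm_embed` mask. The self-contained sketch below shows that window-intersection logic, with plain dicts standing in for `scheduler_output`, `req_state`, and the encoder cache; it omits details the real method still handles (e.g. `is_embed` sub-masks and `shift_computed_tokens`).

```python
import torch


def gather_mm_embeds_sketch(
    num_scheduled: dict[str, int],                    # req_id -> tokens scheduled this step
    num_computed: dict[str, int],                     # req_id -> tokens already computed
    placeholders: dict[str, list[tuple[int, int]]],   # req_id -> (start_pos, length) per mm item
    encoder_cache: dict[str, list[torch.Tensor]],     # req_id -> encoder output per mm item
) -> list[torch.Tensor]:
    mm_embeds: list[torch.Tensor] = []
    for req_id, scheduled in num_scheduled.items():
        window_start = num_computed[req_id]
        window_end = window_start + scheduled
        for (start_pos, length), output in zip(placeholders[req_id], encoder_cache[req_id]):
            # Intersect this mm item's placeholder range with the scheduled token window.
            start_idx = max(window_start, start_pos) - start_pos
            end_idx = min(window_end, start_pos + length) - start_pos
            if start_idx < end_idx:
                mm_embeds.append(output[start_idx:end_idx])
    return mm_embeds


# One request: an image spans positions 2..8, and the first 6 prompt tokens are scheduled.
embeds = gather_mm_embeds_sketch(
    {"req0": 6}, {"req0": 0}, {"req0": [(2, 6)]}, {"req0": [torch.randn(6, 8)]})
assert embeds[0].shape[0] == 4   # positions 2..5 of the image overlap the window
```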
@@ -2963,16 +2939,13 @@ def execute_model( |
2963 | 2939 | with self.profiler.record_event('internal', 'prepare_input_encoders'): |
2964 | 2940 | self._execute_mm_encoder(scheduler_output, req_id) |
2965 | 2941 |
2966 | | - mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output, |
2967 | | - req_id, |
2968 | | - total_num_scheduled_tokens=token_ids.shape[-1]) |
| 2942 | + mm_embeds = self._gather_mm_embeddings(scheduler_output, req_id) |
2969 | 2943 | # TODO: Only get embeddings for valid token_ids. Ignore token_ids[<pad_idxs>] # noqa E501 |
2970 | 2944 | # This may require moving multimodal input preps into _prepare_inputs, # noqa E501 |
2971 | 2945 | # to avoid padding issues. |
2972 | 2946 | inputs_embeds = self.model.get_input_embeddings( |
2973 | | - token_ids, |
2974 | | - multimodal_embeddings=mm_embeds, |
2975 | | - is_multimodal=is_mm_embed, |
| 2947 | + input_ids=token_ids, |
| 2948 | + multimodal_embeddings=mm_embeds or None, |
2976 | 2949 | ) |
2977 | 2950 |
2978 | 2951 | model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) |
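At the call site the gathered list is passed through `mm_embeds or None`, so a step that gathered nothing (a text-only batch) hands `None` to `get_input_embeddings` rather than an empty list. The minimal stub below just demonstrates that contract; the stub class and its branching are illustrative, not the HPU wrapper's real behavior.

```python
import torch


class EmbedderStub:
    """Stand-in for model.get_input_embeddings to show the empty-list -> None contract."""

    def __init__(self, vocab: int = 32, hidden: int = 8):
        self.table = torch.nn.Embedding(vocab, hidden)

    def get_input_embeddings(self, input_ids, multimodal_embeddings=None):
        embeds = self.table(input_ids)
        if multimodal_embeddings is None:
            return embeds                 # text-only path: no merge attempted
        raise NotImplementedError("merge would proceed as in the sketch further above")


stub = EmbedderStub()
mm_embeds: list[torch.Tensor] = []        # nothing gathered this step
_ = stub.get_input_embeddings(torch.tensor([1, 2, 3]), multimodal_embeddings=mm_embeds or None)
```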