                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -283,17 +283,19 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
         """
         if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
             # Get embeddings of input.
             # shape: (batch_size, seq_len, d_model)
-            inputs_embeds = self.embed_tokens(input_ids)
+            else:
+                hidden_states = self.embed_tokens(input_ids)
 
-            # embed positions
-            hidden_states = inputs_embeds
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -337,7 +339,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
+        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
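The hunk above, together with the import change in the first hunk, replaces direct construction of `Sampler()` with the `get_sampler()` factory, so the model class no longer hard-codes one sampler implementation. As a rough sketch of the factory pattern involved (hypothetical names and selection condition; vLLM's real `get_sampler()` keys off its own engine configuration, not an environment variable):

import os

class DefaultSampler:
    """Stand-in for one sampler implementation."""
    def name(self) -> str:
        return "default"

class AltSampler(DefaultSampler):
    """Stand-in for an alternative implementation."""
    def name(self) -> str:
        return "alt"

def get_sampler() -> DefaultSampler:
    # Factory: choose the implementation once, at runtime, instead of
    # hard-coding a class at every model definition site. The condition
    # here is a placeholder for illustration only.
    if os.environ.get("USE_ALT_SAMPLER"):
        return AltSampler()
    return DefaultSampler()

sampler = get_sampler()
print(sampler.name())  # "default" unless USE_ALT_SAMPLER is set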
@@ -346,11 +348,13 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
         )
         return hidden_states
 
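Taken together, the two `forward()` hunks thread an optional `inputs_embeds` tensor from the top-level model down to the embedding lookup, so a caller can supply precomputed embeddings instead of having token IDs re-embedded on the first pipeline rank. A self-contained sketch of that branch (toy sizes and class name; not vLLM's actual module):

from typing import Optional

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Minimal stand-in for the embedding-selection logic added above."""

    def __init__(self, vocab_size: int = 32, d_model: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Mirror the diff: prefer caller-supplied embeddings, otherwise
        # embed the token IDs as before.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)
        return hidden_states

model = TinyModel()
ids = torch.randint(0, 32, (1, 8))
# Passing the same embeddings explicitly must match the default path.
assert torch.equal(model(ids, inputs_embeds=model.embed_tokens(ids)), model(ids))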