Commit 5340b0e

[Bugfix] Fix interface for Olmo2 on V1 (#14976)
Signed-off-by: Roger Wang <ywang@roblox.com>
1 parent 37e3806 commit 5340b0e

File tree

1 file changed (+9, -5 lines)


vllm/model_executor/models/olmo2.py

Lines changed: 9 additions & 5 deletions
@@ -42,7 +42,7 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -283,17 +283,19 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
         """
         if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
             # Get embeddings of input.
             # shape: (batch_size, seq_len, d_model)
-            inputs_embeds = self.embed_tokens(input_ids)
+            else:
+                hidden_states = self.embed_tokens(input_ids)

-            # embed positions
-            hidden_states = inputs_embeds
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -337,7 +339,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
+        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

@@ -346,11 +348,13 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
         )
         return hidden_states
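
Note (not part of the commit itself): the hunks above add an optional inputs_embeds argument to both forward methods and route it down to Olmo2Model, which now prefers caller-supplied embeddings over its own embed_tokens lookup; the direct Sampler() construction is also swapped for get_sampler(), presumably so the sampler implementation can be selected centrally. Below is a minimal, self-contained sketch of the inputs_embeds contract the diff adopts; ToyModel, its sizes, and the tensors are hypothetical illustrations, not vLLM code.

# Sketch of the inputs_embeds contract from the diff above (hypothetical ToyModel).
from typing import Optional

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden_size: int = 16) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Same branch as the patched Olmo2Model.forward on the first PP rank:
        # prefer caller-supplied embeddings, otherwise embed the token IDs.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)
        return hidden_states


if __name__ == "__main__":
    model = ToyModel()
    token_ids = torch.tensor([1, 2, 3])

    # Token-ID path: embeddings are looked up inside the model.
    out_ids = model(token_ids)

    # Precomputed-embeddings path: the tensor is used as-is.
    precomputed = torch.randn(3, 16)
    out_embeds = model(token_ids, inputs_embeds=precomputed)
    assert torch.equal(out_embeds, precomputed)
    print(out_ids.shape, out_embeds.shape)

With this contract, a runner can hand the model precomputed embeddings (for example, merged multimodal embeddings) instead of having the model re-embed input_ids, which appears to be what the V1 interface expects here.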
