Commit a7f7011

WoosukKwon authored and Alvant committed
[Bugfix] Fix embedding to support 2D inputs (vllm-project#5829)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent adb3867 commit a7f7011

1 file changed (+2, -2 lines changed)

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -306,11 +306,11 @@ def forward(self, input_):
                 self.shard_indices.added_vocab_end_index)
         else:
             masked_input = input_
-            # Get the embeddings.
+        # Get the embeddings.
         output_parallel = F.embedding(masked_input.long(), self.weight)
         # Mask the output embedding.
         if self.tp_size > 1:
-            output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
+            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         # Reduce across all the model parallel GPUs.
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output
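
Why the change from unsqueeze(1) to unsqueeze(-1) matters for 2D inputs can be seen from shapes alone. The following is a minimal sketch in plain Python (no PyTorch needed) that mimics the shape behaviour of unsqueeze and the right-aligned broadcasting rule that masked_fill_ relies on. The helper names unsqueeze_shape and broadcastable_to, and the sizes B, T, H, are hypothetical and chosen only for illustration; they are not part of vLLM.

```python
# Shape-only sketch (assumption: masked_fill_ needs the mask to broadcast
# to the shape of the tensor being filled). Helper names are hypothetical.

def unsqueeze_shape(shape, dim):
    """Shape after inserting a size-1 axis at `dim` (negative dims count from the end)."""
    ndim = len(shape)
    if dim < 0:
        dim += ndim + 1
    if not 0 <= dim <= ndim:
        raise IndexError("dim out of range")
    return shape[:dim] + (1,) + shape[dim:]


def broadcastable_to(mask_shape, target_shape):
    """True if a mask of `mask_shape` can broadcast to `target_shape`."""
    if len(mask_shape) > len(target_shape):
        return False
    # Compare dimensions right-aligned: each must match or be 1.
    for m, t in zip(reversed(mask_shape), reversed(target_shape)):
        if m != 1 and m != t:
            return False
    return True


B, T, H = 2, 3, 4  # illustrative batch size, sequence length, hidden size

# 1D token ids: mask is (T,), embeddings are (T, H). Both variants agree.
old_1d = unsqueeze_shape((T,), 1)    # (T, 1)
new_1d = unsqueeze_shape((T,), -1)   # (T, 1)
print(old_1d, new_1d, broadcastable_to(new_1d, (T, H)))

# 2D token ids: mask is (B, T), embeddings are (B, T, H).
old_2d = unsqueeze_shape((B, T), 1)   # (B, 1, T): last axis is T, not H
new_2d = unsqueeze_shape((B, T), -1)  # (B, T, 1): broadcasts over H
print(old_2d, broadcastable_to(old_2d, (B, T, H)))
print(new_2d, broadcastable_to(new_2d, (B, T, H)))
```

With 1D inputs the two calls produce the same mask shape, which is presumably why the old code worked until 2D inputs were passed. Note that if T happened to equal H, the old mask shape would broadcast without an error but mask the wrong elements.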
