Skip to content

Commit 5c7a13f

Browse files
jeejeelee authored and LeiWang1999 committed
[Misc] Adjust max_position_embeddings for LoRA compatibility (vllm-project#8957)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent afc5484 commit 5c7a13f

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/worker/model_runner.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1037,9 +1037,17 @@ def load_model(self) -> None:
10371037
assert supports_lora(
10381038
self.model
10391039
), f"{self.model.__class__.__name__} does not support LoRA yet."
1040+
10401041
if supports_multimodal(self.model):
10411042
logger.warning("Regarding multimodal models, vLLM currently "
10421043
"only supports adding LoRA to language model.")
1044+
# It's necessary to distinguish between the max_position_embeddings
1045+
# of VLMs and LLMs.
1046+
if hasattr(self.model.config, "max_position_embeddings"):
1047+
max_pos_embeddings = self.model.config.max_position_embeddings
1048+
else:
1049+
max_pos_embeddings = (
1050+
self.model.config.text_config.max_position_embeddings)
10431051

10441052
self.lora_manager = LRUCacheWorkerLoRAManager(
10451053
self.scheduler_config.max_num_seqs,
@@ -1049,8 +1057,7 @@ def load_model(self) -> None:
10491057
self.device,
10501058
self.model.embedding_modules,
10511059
self.model.embedding_padding_modules,
1052-
max_position_embeddings=self.model.config.
1053-
max_position_embeddings,
1060+
max_position_embeddings=max_pos_embeddings,
10541061
)
10551062
self.model = self.lora_manager.create_lora_manager(self.model)
10561063

0 commit comments

Comments (0)