Commit cff864e

patrickvonplaten authored and weilong.yu committed
[Pixtral] Improve loading (vllm-project#11040)
1 parent 51ceb45 commit cff864e

1 file changed: +25 -31 lines changed

vllm/model_executor/models/pixtral.py

Lines changed: 25 additions & 31 deletions
@@ -1,6 +1,5 @@
 from dataclasses import dataclass, fields
 from functools import cached_property
-from itertools import tee
 from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 import numpy
@@ -359,38 +358,33 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")
 
-        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
-            return is_vision_encoder_weights(
-                weight) or is_vision_lang_adapter_weights(weight)
-
-        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
-            weights, 3)
-
-        # llm
-        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
-        self.language_model.load_weights(llm_weights)
-
-        # vision encoder
-        vision_encoder_weights = filter(is_vision_encoder_weights,
-                                        vision_encoder_weights)
+        # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
-        for name, loaded_weight in vision_encoder_weights:
-            # cut 'vision_encoder.'
-            name = '.'.join(name.split(".")[1:])
-            param = vision_encoder_dict[name]
-
-            default_weight_loader(param, loaded_weight)
-
-        # adapter
-        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
-                                             vision_lang_adapter_weights)
-        vision_lang_adpter_dict = dict(
+        vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
-        for name, loaded_weight in vision_lang_adapter_weights:
-            # cut 'vision_language_adapter.'
-            name = '.'.join(name.split(".")[1:])
-            param = vision_lang_adpter_dict[name]
-            default_weight_loader(param, loaded_weight)
+
+        def llm_weights_generator():
+            # Single pass over weights
+            for name, w in weights:
+                if is_vision_encoder_weights((name, w)):
+                    # Load vision encoder weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = vision_encoder_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_vision_lang_adapter_weights((name, w)):
+                    # Load vision-language adapter weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = vision_lang_adapter_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                else:
+                    # LLM weights: yield them to be loaded
+                    # by language_model.load_weights
+                    yield (name, w)
+
+        # Now we call the language model load with the generator
+        self.language_model.load_weights(llm_weights_generator())
 
 
 # Vision encoder
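Why the change helps: the deleted code walked the checkpoint stream three times via itertools.tee(weights, 3). When one tee branch is drained ahead of the others, tee must buffer every item the lagging branches have not yet consumed, so with checkpoint-sized tensors much of the weight stream could end up held in memory at once. The new llm_weights_generator makes a single pass: vision encoder and adapter weights are copied into their parameters as they stream by, and only the LLM weights are yielded to language_model.load_weights. Below is a minimal self-contained sketch of the same single-pass pattern; the names stream_weights, is_vision, and split_and_load are illustrative stand-ins, not part of the vLLM codebase.

from typing import Iterable, Iterator, Tuple

import torch

# Illustrative stand-in for a checkpoint weight stream.
def stream_weights() -> Iterator[Tuple[str, torch.Tensor]]:
    yield "vision_encoder.layers.0.w", torch.ones(4)
    yield "lm.layers.0.w", torch.ones(4)

def is_vision(name: str) -> bool:
    return name.startswith("vision_encoder")

def split_and_load(
        weights: Iterable[Tuple[str, torch.Tensor]],
        vision_params: dict) -> Iterator[Tuple[str, torch.Tensor]]:
    """Single pass: load vision weights eagerly, yield the rest."""
    for name, w in weights:
        if is_vision(name):
            # Copy into the live parameter without tracking gradients,
            # standing in for default_weight_loader under torch.no_grad().
            with torch.no_grad():
                vision_params[name.split(".", 1)[1]].copy_(w)
        else:
            yield name, w

vision_params = {"layers.0.w": torch.nn.Parameter(torch.empty(4))}
# The downstream loader drains the generator, so every weight is
# visited exactly once and nothing is buffered.
llm_weights = list(split_and_load(stream_weights(), vision_params))
assert [n for n, _ in llm_weights] == ["lm.layers.0.w"]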
