 from dataclasses import dataclass, fields
 from functools import cached_property
-from itertools import tee
 from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 import numpy
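Note on the import change: `itertools.tee` keeps an internal buffer of every item that one sub-iterator has consumed but the others have not yet seen, so splitting a lazy weights iterator three ways and then draining the copies one after another effectively materializes the whole stream in memory at once. A minimal, self-contained sketch of that behaviour, with toy integers standing in for the `(name, tensor)` pairs (not code from this patch):

from itertools import tee

def stream():
    # Stand-in for a lazily produced sequence of (name, tensor) pairs.
    for i in range(5):
        print("producing", i)
        yield i

a, b, c = tee(stream(), 3)   # split three ways, as the old code did

# Draining one copy first forces tee to buffer every item internally
# so that the other two copies can still replay the full stream later.
print(list(a))   # prints "producing 0" .. "producing 4", then [0, 1, 2, 3, 4]
print(list(b))   # served from tee's buffer, nothing is produced again
print(list(c))   # likewise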
@@ -359,38 +358,33 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
     def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
         return weight[0].startswith("vision_language_adapter")
 
-        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
-            return is_vision_encoder_weights(
-                weight) or is_vision_lang_adapter_weights(weight)
-
-        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
-            weights, 3)
-
-        # llm
-        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
-        self.language_model.load_weights(llm_weights)
-
-        # vision encoder
-        vision_encoder_weights = filter(is_vision_encoder_weights,
-                                        vision_encoder_weights)
+        # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
-        for name, loaded_weight in vision_encoder_weights:
-            # cut 'vision_encoder.'
-            name = '.'.join(name.split(".")[1:])
-            param = vision_encoder_dict[name]
-
-            default_weight_loader(param, loaded_weight)
-
-        # adapter
-        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
-                                             vision_lang_adapter_weights)
-        vision_lang_adpter_dict = dict(
+        vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
-        for name, loaded_weight in vision_lang_adapter_weights:
-            # cut 'vision_language_adapter.'
-            name = '.'.join(name.split(".")[1:])
-            param = vision_lang_adpter_dict[name]
-            default_weight_loader(param, loaded_weight)
+
+        def llm_weights_generator():
+            # Single pass over weights
+            for name, w in weights:
+                if is_vision_encoder_weights((name, w)):
+                    # Load vision encoder weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = vision_encoder_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_vision_lang_adapter_weights((name, w)):
+                    # Load vision-language adapter weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = vision_lang_adapter_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                else:
+                    # LLM weights: yield them to be loaded
+                    # by language_model.load_weights
+                    yield (name, w)
+
+        # Now we call the language model load with the generator
+        self.language_model.load_weights(llm_weights_generator())
 
 
 # Vision encoder
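The replacement avoids the buffering entirely by making a single pass: vision-encoder and adapter weights are loaded in place as a side effect, and every other weight is yielded lazily to `language_model.load_weights`, which drives the generator. A stripped-down sketch of the same pattern, using hypothetical names (`load_all`, a `special.` prefix) that are not part of the patch:

from typing import Dict, Iterable, Tuple

def load_all(items: Iterable[Tuple[str, int]]):
    special: Dict[str, int] = {}

    def rest_generator():
        # Single pass over the input: handle "special" items immediately
        # as a side effect, and yield everything else to the caller below.
        for name, value in items:
            if name.startswith("special."):
                special[name] = value
            else:
                yield name, value

    # The downstream consumer drives the generator, so no second copy
    # of the stream is ever materialized.
    rest = dict(rest_generator())
    return special, rest

special, rest = load_all([("special.a", 1), ("llm.b", 2), ("llm.c", 3)])
print(special)   # {'special.a': 1}
print(rest)      # {'llm.b': 2, 'llm.c': 3}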