@@ -20,15 +20,21 @@
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsV0Only
-from vllm.model_executor.models.interfaces import SupportsQuant
-from vllm.model_executor.models.utils import WeightsMapper
-from vllm.sequence import IntermediateTensors
+from vllm.model_executor.models.bert import BertPooler
+from vllm.model_executor.models.interfaces import (SupportsCrossEncoding,
+                                                   SupportsQuant)
+from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.transformers_utils.config import (
+    get_cross_encoder_activation_function)
 
 logger = init_logger(__name__)
 
@@ -405,16 +411,23 @@ def forward(
 class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 add_pooling_layer=False):
         super().__init__()
         self.vllm_config = vllm_config
+        self.add_pooling_layer = add_pooling_layer
         self.config = self.config_verify(vllm_config)
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
             bias=getattr(self.config, "bias", True),
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")
+        if self.add_pooling_layer:
+            self.pooler = BertPooler(self.config)
 
     def config_verify(self, vllm_config):
         raise NotImplementedError
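Note: the `BertPooler` wired in above (imported from `vllm.model_executor.models.bert`) follows the classic BERT pooling head: a dense layer plus tanh applied to the first (CLS) token. Below is a minimal sketch of that pattern, assuming a standard `[batch, seq, hidden]` layout; vLLM's actual `BertPooler` additionally handles its flattened batch format, so this is illustrative only.

```python
import torch
from torch import nn


class FirstTokenPooler(nn.Module):
    """Sketch of the classic BERT pooling head: dense + tanh over the
    first (CLS) token. Illustrative only; not vLLM's exact BertPooler."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch, seq_len, hidden_size]
        return self.activation(self.dense(hidden_states[:, 0]))
```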
@@ -450,7 +463,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "pooler" in name:
+            if not self.add_pooling_layer and "pooler" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
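Note: the surrounding loop is vLLM's standard stacked-parameter loading pattern. Fused modules such as `QKVParallelLinear` hold q/k/v in one parameter, while checkpoints store them separately, so names are rewritten and each tensor is loaded into its shard. A hypothetical, simplified illustration of the name-rewriting step (the mapping entries here are assumptions for demonstration, not copied from this file):

```python
from typing import Optional

# Hypothetical mapping in the style of vLLM's stacked_params_mapping:
# (fused vLLM param name, checkpoint weight name, shard_id)
STACKED_PARAMS_MAPPING = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def rewrite_name(name: str) -> tuple[str, Optional[str]]:
    """Map a checkpoint weight name onto its fused vLLM parameter."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            # e.g. "encoder.layer.0.attn.q_proj.weight"
            #   -> ("encoder.layer.0.attn.qkv_proj.weight", "q")
            return name.replace(weight_name, param_name), shard_id
    return name, None  # unfused weight: loaded directly
```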
@@ -591,8 +604,8 @@ class GteNewModel(BertWithRope):
         "attention.o_proj": "attn.out_proj",
     })
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
+        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
         # In GteNewModel, only gate_up_proj has no bias.
         # Hack borrowed from vllm/model_executor/models/glm.py
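Note: the "hack" referenced in the comment sits outside this hunk's context. Judging from the comment and the glm.py pattern it cites, it removes the bias from each layer's fused `gate_up_proj` after construction. A sketch under that assumption (the actual lines are not shown in this diff):

```python
from torch import nn


def drop_gate_up_bias(encoder: nn.Module) -> None:
    """Assumed shape of the hack: this checkpoint has no bias on
    gate_up_proj, so drop it and tell the fused layer to skip it."""
    for layer in encoder.layer:
        layer.mlp.gate_up_proj.bias = None
        layer.mlp.gate_up_proj.skip_bias_add = True
```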
@@ -762,3 +775,65 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         weights = self.jina_merge_lora_weights(weights)
         return super().load_weights(weights)
+
+
+class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                      SupportsQuant):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+
+        self.default_activation_function = \
+            get_cross_encoder_activation_function(config)
+
+        self.num_labels = config.num_labels
+        self.new = GteNewModel(vllm_config=vllm_config,
+                               prefix=maybe_prefix(prefix, "new"),
+                               add_pooling_layer=True)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self._pooler = ClassifierPooler(vllm_config.model_config,
+                                        self.classifier, self.new.pooler)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("new."):
+                    yield (name[len("new."):], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.new.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.new(input_ids=input_ids,
+                        positions=positions,
+                        inputs_embeds=inputs_embeds,
+                        intermediate_tensors=intermediate_tensors,
+                        token_type_ids=token_type_ids)
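With this class in place, GTE "new"-architecture rerankers can be served as cross-encoders through vLLM's scoring task. A hedged usage sketch: the model name is an assumption (a GTE multilingual reranker reported to use this architecture), and it requires a vLLM build that includes this change.

```python
from vllm import LLM

# Assumed model name; any checkpoint whose architecture maps to
# GteNewForSequenceClassification should work the same way.
llm = LLM(model="Alibaba-NLP/gte-multilingual-reranker-base",
          task="score")

# Cross-encoder scoring: each (query, document) pair is scored jointly.
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is in Paris."])
for out in outputs:
    print(out.outputs.score)
```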