@@ -572,6 +572,220 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
     return token_type_ids


+class BertMLMHead(nn.Module):
+    def __init__(
+        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
+    ):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.GELU()
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
+
+    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
+        self.decoder.weight = embeddings_weight
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        x = self.dense(hidden_states)
+        x = self.activation(x)
+        x = self.layer_norm(x)
+        logits = self.decoder(x)
+        return logits
+
+
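A quick standalone shape check of the head added above (toy sizes, not from this PR). Note that `tie_weights_with_embeddings` should receive the embedding `nn.Parameter` itself, so the decoder shares storage with the input embeddings:

    import torch
    from torch import nn

    # Toy dimensions; BERT-base would use hidden_size=768, vocab_size=30522.
    head = BertMLMHead(hidden_size=8, vocab_size=32)

    # Tie weights: pass the embedding Parameter so both modules share storage.
    embeddings = nn.Embedding(32, 8)
    head.tie_weights_with_embeddings(embeddings.weight)
    assert head.decoder.weight.data_ptr() == embeddings.weight.data_ptr()

    hidden = torch.randn(5, 8)   # [L, H] token representations
    logits = head(hidden)        # [L, V] per-token vocabulary logits
    assert logits.shape == (5, 32)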
+class SPLADESparsePooler(Pooler):
+    """
+    SPLADE sparse pooling:
+        logits = mlm_head(hidden_states)   # [L, V]
+        scores = log1p(relu(logits))       # [L, V]
+        pooled = max (or sum) over L       # [V]
+
+    Each prompt is sliced out of the flattened hidden states by its
+    length (so padding never enters the pool), [CLS]/[SEP] tokens are
+    optionally removed, and the remaining token scores are pooled.
+    """
+
+    def __init__(
+        self,
+        mlm_head: nn.Module,
+        cls_token_id: Optional[int] = 101,
+        sep_token_id: Optional[int] = 102,
+        pooling: str = "max",
+        remove_cls_sep: bool = True,
+    ):
+        super().__init__()
+        assert pooling in ("max", "sum")
+        self.mlm_head = mlm_head
+        self.cls_token_id = cls_token_id
+        self.sep_token_id = sep_token_id
+        self.pooling = pooling
+        self.remove_cls_sep = remove_cls_sep
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"embed"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> torch.Tensor:
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
+
+        lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
+        lens: list[int] = lens_tensor.tolist()
+        B: int = len(lens)
+
+        token_ids = pooling_metadata.prompt_token_ids
+        offset = 0
+        pooled_list: list[torch.Tensor] = []
+
+        for i in range(B):
+            L = int(lens[i])
+            hs = hidden_states[offset : offset + L]
+
+            start_idx = 0
+            end_idx = L
+            if self.remove_cls_sep and token_ids is not None:
+                if (
+                    self.cls_token_id is not None
+                    and token_ids[i, 0].item() == self.cls_token_id
+                ):
+                    start_idx = 1
+                if (
+                    self.sep_token_id is not None
+                    and token_ids[i, L - 1].item() == self.sep_token_id
+                ):
+                    end_idx = max(start_idx, L - 1)
+
+            if end_idx <= start_idx:
+                V = int(self.mlm_head.decoder.out_features)
+                pooled_list.append(hs.new_zeros((V,)))
+                offset += L
+                continue
+
+            logits_i = self.mlm_head(hs[start_idx:end_idx])
+            scores_i = torch.log1p(torch.relu(logits_i))
+
+            if self.pooling == "sum":
+                pooled_i = scores_i.sum(dim=0)
+            else:  # "max"
+                pooled_i = scores_i.max(dim=0).values
+
+            pooled_list.append(pooled_i.contiguous())
+            offset += L
+
+        return torch.stack(pooled_list, dim=0).contiguous()
+
+
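To make the pooling concrete, here is a tiny self-contained sketch of the same math on hand-built tensors. It mirrors the offset-based loop above (vLLM packs variable-length prompts into one flat tensor with no padding) but skips the CLS/SEP trimming and uses random stand-in logits:

    import torch

    # Two prompts of lengths 4 and 3, flattened into one [7, V] logits tensor.
    V = 6
    lens = [4, 3]
    logits = torch.randn(sum(lens), V)

    offset, pooled = 0, []
    for L in lens:
        per_prompt = logits[offset : offset + L]      # [L, V] slice for this prompt
        scores = torch.log1p(torch.relu(per_prompt))  # SPLADE activation, >= 0
        pooled.append(scores.max(dim=0).values)       # "max" pooling -> [V]
        offset += L

    out = torch.stack(pooled)  # [B, V] non-negative, mostly-sparse embeddings
    assert out.shape == (2, V) and (out >= 0).all()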
+@default_pooling_type("CLS")
+class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
+    """
+    BertEmbeddingModel + SPLADE sparse embedding:
+    - computes MLM logits via self.mlm_head
+    - pools them with SPLADESparsePooler(mlm_head, ...)
+    """
+
+    def __init__(
+        self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
+    ):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        cfg = vllm_config.model_config.hf_config
+
+        # MLM head
+        self.mlm_head = BertMLMHead(
+            hidden_size=cfg.hidden_size,
+            vocab_size=cfg.vocab_size,
+            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
+        )
+
+        self._splade_pooling = splade_pooling
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+        self.pooler = self._build_pooler(pooler_config)
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        cfg = self.model.config
+
+        if not hasattr(self, "mlm_head"):
+            self.mlm_head = BertMLMHead(
+                hidden_size=cfg.hidden_size,
+                vocab_size=cfg.vocab_size,
+                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
+            )
+
+        pooling_mode = getattr(self, "_splade_pooling", "max")
+
+        cls_id = getattr(cfg, "cls_token_id", None)
+        sep_id = getattr(cfg, "sep_token_id", None)
+
+        return DispatchPooler(
+            {
+                "encode": Pooler.for_encode(pooler_config),
+                "embed": SPLADESparsePooler(
+                    mlm_head=self.mlm_head,
+                    cls_token_id=cls_id,
+                    sep_token_id=sep_id,
+                    pooling=pooling_mode,  # "max" or "sum"
+                    remove_cls_sep=True,
+                ),
+            }
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        if not hasattr(self, "mlm_head"):
+            cfg = self.model.config
+            self.mlm_head = BertMLMHead(
+                hidden_size=cfg.hidden_size,
+                vocab_size=cfg.vocab_size,
+                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
+            )
+
+        def _strip(name: str) -> str:
+            for p in ("model.", "bert."):
+                if name.startswith(p):
+                    name = name[len(p):]
+            return name
+
+        weights_list = list(weights)
+        model_side: list[tuple[str, torch.Tensor]] = []
+        mlm_side: list[tuple[str, torch.Tensor]] = []
+
+        for k, w in weights_list:
+            name = _strip(k)
+            if name.startswith("cls.predictions."):
+                mlm_side.append((name, w))
+            else:
+                model_side.append((name, w))
+
+        loaded: set[str] = set()
+        loaded_model = self.model.load_weights(model_side)
+        loaded.update({"model." + n for n in loaded_model})
+
+        if mlm_side:
+            name_map = {
+                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
+                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
+                "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
+                "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
+                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
+                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
+            }
+            remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
+            if remapped:
+                loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
+                loaded.update(loaded_mlm)
+
+        return loaded
+
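The key-splitting in `load_weights` is pure string manipulation, so it can be checked in isolation. A small sketch with typical HuggingFace BERT checkpoint names (the helper mirrors the nested `_strip` above; the keys are standard BERT checkpoint entries):

    # Illustrative check of the checkpoint-name routing used by load_weights.
    def strip_prefix(name: str) -> str:
        for p in ("model.", "bert."):
            if name.startswith(p):
                name = name[len(p):]
        return name

    checkpoint_keys = [
        "bert.encoder.layer.0.attention.self.query.weight",  # -> model side
        "cls.predictions.transform.dense.weight",            # -> mlm_head.dense.weight
        "cls.predictions.decoder.weight",                    # -> mlm_head.decoder.weight
    ]

    for key in checkpoint_keys:
        name = strip_prefix(key)
        side = "mlm" if name.startswith("cls.predictions.") else "model"
        print(f"{key!r:55s} -> {side}")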
+
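Assuming this model class is registered with vLLM, offline usage could look roughly like the following. The checkpoint name is just a public SPLADE example, and the exact `LLM` task/runner argument names vary across vLLM versions, so treat this as a sketch rather than the PR's documented interface:

    from vllm import LLM

    # Hypothetical usage; argument names are version-dependent.
    llm = LLM(model="naver/splade-cocondenser-ensembledistil", task="embed")

    outputs = llm.embed(["what is sparse retrieval?"])
    vec = outputs[0].outputs.embedding  # vocab-sized vector, mostly zeros
    print(len(vec))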
 @default_pooling_type("CLS")
 class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.