@@ -58,22 +58,27 @@ def __init__(
58
58
) -> None :
59
59
super ().__init__ (vllm_config = vllm_config , prefix = prefix , ** kwargs )
60
60
61
+ self .vllm_config = vllm_config
62
+
61
63
# These are not used in pooling models
62
64
for attr in ("lm_head" , "logits_processor" ):
63
65
if hasattr (self , attr ):
64
66
delattr (self , attr )
65
67
68
+ # If the model already defines a pooler instance, don't overwrite it
69
+ if not getattr (self , "_pooler" , None ):
70
+ self ._init_pooler (vllm_config , prefix = prefix )
71
+
72
+ def _init_pooler (self , vllm_config : "VllmConfig" , prefix : str = "" ):
66
73
pooler_config = vllm_config .model_config .pooler_config
67
74
assert pooler_config is not None
68
75
69
- # If the model already defines a pooler instance, don't overwrite it
70
- if not getattr (self , "_pooler" , None ):
71
- self ._pooler = Pooler .from_config_with_defaults (
72
- pooler_config ,
73
- pooling_type = default_pooling_type ,
74
- normalize = default_normalize ,
75
- softmax = default_softmax ,
76
- )
76
+ self ._pooler = Pooler .from_config_with_defaults (
77
+ pooler_config ,
78
+ pooling_type = default_pooling_type ,
79
+ normalize = default_normalize ,
80
+ softmax = default_softmax ,
81
+ )
77
82
78
83
def pooler (
79
84
self ,
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
165
170
166
171
# Lazy import
167
172
from vllm .model_executor .layers .linear import RowParallelLinear
168
- from vllm .model_executor .layers .pooler import PoolerOutput , PoolingType
173
+ from vllm .model_executor .layers .pooler import (ClassifierPooler ,
174
+ PoolerOutput , PoolingType ,
175
+ SimplePooler )
169
176
from vllm .model_executor .models .interfaces import SupportsCrossEncoding
170
177
from vllm .model_executor .pooling_metadata import PoolingMetadata
171
178
from vllm .sequence import IntermediateTensors
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
182
189
class ModelForSequenceClassification (ModelForPooling ,
183
190
SupportsCrossEncoding ):
184
191
185
- def __init__ (
186
- self ,
187
- * ,
188
- vllm_config : "VllmConfig" ,
189
- prefix : str = "" ,
190
- ** kwargs : Any ,
191
- ) -> None :
192
- super ().__init__ (vllm_config = vllm_config , prefix = prefix , ** kwargs )
193
-
192
+ def _init_pooler (self , vllm_config : "VllmConfig" , prefix : str = "" ):
194
193
config = vllm_config .model_config .hf_config
195
194
quant_config = vllm_config .quant_config
196
195
197
- self .vllm_config = vllm_config
198
- self .task = vllm_config .model_config .task
199
- self .pooling_type = (
200
- vllm_config .model_config .pooler_config .pooling_type )
201
-
202
- self .score = RowParallelLinear (config .hidden_size ,
203
- config .num_labels ,
204
- quant_config = quant_config ,
205
- input_is_parallel = False ,
206
- bias = False ,
207
- prefix = maybe_prefix (
208
- prefix , "score" ))
196
+ self .score = RowParallelLinear (
197
+ config .hidden_size ,
198
+ config .num_labels ,
199
+ input_is_parallel = False ,
200
+ bias = False ,
201
+ params_dtype = torch .float32 ,
202
+ quant_config = quant_config ,
203
+ prefix = maybe_prefix (prefix , "score" ),
204
+ )
205
+
206
+ pooler_config = vllm_config .model_config .pooler_config
207
+ assert pooler_config is not None
208
+
209
+ pooler = SimplePooler .from_config_with_defaults (
210
+ pooler_config ,
211
+ pooling_type = PoolingType .LAST ,
212
+ normalize = False ,
213
+ softmax = True ,
214
+ )
215
+
216
+ self ._pooler = ClassifierPooler (
217
+ vllm_config .model_config ,
218
+ pooling = pooler .pooling ,
219
+ classifier = self ._classifier ,
220
+ act_fn = pooler .head .activation ,
221
+ )
222
+
223
+ def _classifier (self , x : torch .Tensor ):
224
+ x , _ = self .score (x .float ())
225
+ return x
209
226
210
227
def forward (
211
228
self ,
@@ -222,27 +239,7 @@ def pooler(
222
239
hidden_states : Union [torch .Tensor , list [torch .Tensor ]],
223
240
pooling_metadata : PoolingMetadata ,
224
241
) -> PoolerOutput :
225
-
226
- def get_logits (hidden_states ):
227
- if isinstance (hidden_states , list ):
228
- logits = [self .score (state )[0 ] for state in hidden_states ]
229
- else :
230
- logits , _ = self .score (hidden_states )
231
- return logits
232
-
233
- if self .pooling_type == PoolingType .ALL :
234
- logits = get_logits (hidden_states )
235
- return self ._pooler (logits , pooling_metadata )
236
- else :
237
- hidden_states = self ._pooler .extract_states (
238
- hidden_states , pooling_metadata )
239
- logits = get_logits (hidden_states )
240
- pooled_data = self ._pooler .head (logits , pooling_metadata )
241
-
242
- pooled_outputs = [
243
- self ._pooler .build_output (data ) for data in pooled_data
244
- ]
245
- return PoolerOutput (outputs = pooled_outputs )
242
+ return self ._pooler (hidden_states , pooling_metadata )
246
243
247
244
def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
248
245
tokens = getattr (self .config , "classifier_from_token" , None )
0 commit comments