9
9
import numpy as np
10
10
import pytest
11
11
import requests
12
+ import torch
12
13
13
14
from tests .models .utils import (EmbedModelInfo , RerankModelInfo ,
14
15
check_embeddings_close )
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
165
166
vllm_extra_kwargs = None ,
166
167
hf_model_callback = None ,
167
168
atol = MTEB_EMBED_TOL ):
169
+ # A model family has many models with the same architecture,
170
+ # and we don't need to test each one.
168
171
if not model_info .enable_test :
169
- # A model family has many models with the same architecture,
170
- # and we don't need to test each one.
171
172
pytest .skip ("Skipping test." )
172
173
173
- example_prompts = ["The chef prepared a delicious meal." ]
174
+ # Test embed_dims, isnan and whether to use normalize
175
+ example_prompts = ["The chef prepared a delicious meal." * 1000 ]
174
176
177
+ # Allow vllm to test using the given dtype, such as float32
175
178
vllm_extra_kwargs = vllm_extra_kwargs or {}
176
179
vllm_extra_kwargs ["dtype" ] = model_info .dtype
177
180
181
+ # Allow vllm to test using hf_overrides
178
182
if model_info .hf_overrides is not None :
179
183
vllm_extra_kwargs ["hf_overrides" ] = model_info .hf_overrides
180
184
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
186
190
187
191
model_config = vllm_model .llm .llm_engine .model_config
188
192
193
+ # Confirm whether vllm is using the correct architecture
189
194
if model_info .architecture :
190
195
assert model_info .architecture in model_config .architectures
196
+
197
+ # Confirm whether vllm uses the correct default_pooling_type, which
198
+ # relates to whether chunked prefill and prefix caching are enabled
191
199
assert (model_config ._model_info .default_pooling_type ==
192
200
model_info .default_pooling_type )
193
201
194
202
vllm_main_score = run_mteb_embed_task (VllmMtebEncoder (vllm_model ),
195
203
MTEB_EMBED_TASKS )
196
204
vllm_dtype = vllm_model .llm .llm_engine .model_config .dtype
197
- vllm_outputs = vllm_model .embed (example_prompts )
198
205
206
+ # Test embed_dims, isnan and whether to use normalize
207
+ vllm_outputs = vllm_model .embed (example_prompts ,
208
+ truncate_prompt_tokens = - 1 )
209
+ assert not torch .any (torch .isnan (torch .tensor (vllm_outputs )))
210
+
211
+ # Accelerate mteb test by setting
212
+ # SentenceTransformers mteb score to a constant
199
213
if model_info .mteb_score is None :
200
214
with hf_runner (model_info .name ,
201
215
is_sentence_transformer = True ,
202
216
dtype = "float32" ) as hf_model :
203
217
218
+ # e.g. setting default parameters for the encode method of hf_runner
204
219
if hf_model_callback is not None :
205
220
hf_model_callback (hf_model )
206
221
@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
299
314
hf_model_callback = None ,
300
315
vllm_mteb_encoder = VllmMtebEncoder ,
301
316
atol = MTEB_RERANK_TOL ):
317
+ # A model family has many models with the same architecture,
318
+ # and we don't need to test each one.
302
319
if not model_info .enable_test :
303
- # A model family has many models with the same architecture,
304
- # and we don't need to test each one.
305
320
pytest .skip ("Skipping test." )
306
321
322
+ # Allow vllm to test using the given dtype, such as float32
307
323
vllm_extra_kwargs = vllm_extra_kwargs or {}
308
324
vllm_extra_kwargs ["dtype" ] = model_info .dtype
309
325
326
+ # Allow vllm to test using hf_overrides
310
327
if model_info .hf_overrides is not None :
311
328
vllm_extra_kwargs ["hf_overrides" ] = model_info .hf_overrides
312
329
@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
319
336
320
337
model_config = vllm_model .llm .llm_engine .model_config
321
338
339
+ # Confirm whether vllm is using the correct architecture
322
340
if model_info .architecture :
323
341
assert (model_info .architecture in model_config .architectures )
342
+
343
+ # Score API is only enabled for num_labels == 1
324
344
assert model_config .hf_config .num_labels == 1
345
+
346
+ # Confirm whether vllm uses the correct default_pooling_type, which
347
+ # relates to whether chunked prefill and prefix caching are enabled
325
348
assert (model_config ._model_info .default_pooling_type ==
326
349
model_info .default_pooling_type )
327
350
@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
330
353
languages = MTEB_RERANK_LANGS )
331
354
vllm_dtype = model_config .dtype
332
355
356
+ # Accelerate mteb test by setting
357
+ # SentenceTransformers mteb score to a constant
333
358
if model_info .mteb_score is None :
334
359
st_main_score , st_dtype = mteb_test_rerank_models_hf (
335
360
hf_runner , model_info .name , hf_model_callback )
0 commit comments