@@ -14,7 +14,34 @@
 
 from ...conftest import VllmRunner
 
-MODELS = ["meta-llama/Llama-3.2-1B"]
+MODEL = "meta-llama/Llama-3.2-1B"
+DTYPE = "half"
+
+
+@pytest.fixture(scope="module")
+def vllm_model(vllm_runner):
+    with vllm_runner(
+            MODEL,
+            dtype=DTYPE,
+            max_logprobs=7,
+            # Very small number of batched tokens to ensure
+            # that we test chunking.
+            max_num_batched_tokens=16,
+            max_num_seqs=16,
+            max_model_len=128,
+            enforce_eager=True,
+            # TODO: enable this once we support it for
+            # prompt logprobs.
+            enable_prefix_caching=False,
+            gpu_memory_utilization=0.5,
+    ) as vllm_model:
+        yield vllm_model
+
+
+@pytest.fixture(scope="module")
+def hf_model(hf_runner):
+    with hf_runner(MODEL, dtype=DTYPE) as hf_model:
+        yield hf_model
 
 
 def _repeat_logprob_config(
@@ -66,30 +93,23 @@ def _repeat_logprob_config(
 
 
 def _test_case_get_logprobs_and_prompt_logprobs(
-    hf_runner,
-    vllm_runner,
-    model: str,
-    dtype: str,
+    hf_model,
+    vllm_model,
     batch_logprobs_composition: str,
-    max_num_batched_tokens: int,
     temperature: float,
     example_prompts,
 ) -> None:
     test_prompts = example_prompts
 
-    max_num_seqs = 16
-    max_model_len = 128
-
     max_tokens = 5
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(
-            test_prompts,
-            max_tokens=max_tokens,
-        )
-        hf_logprobs = hf_model.generate_greedy_logprobs(
-            test_prompts,
-            max_tokens=max_tokens,
-        )
+    hf_outputs = hf_model.generate_greedy(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
+    hf_logprobs = hf_model.generate_greedy_logprobs(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
 
     # Batch has mixed sample params
     # (different logprobs/prompt logprobs combos)
@@ -108,20 +128,8 @@ def _test_case_get_logprobs_and_prompt_logprobs(
         for num_lp, num_plp in logprob_prompt_logprob_list
     ]
 
-    with vllm_runner(
-            model,
-            dtype=dtype,
-            max_logprobs=7,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            max_model_len=max_model_len,
-            enforce_eager=True,
-            # TODO: enable this once we support it for
-            # prompt logprobs.
-            enable_prefix_caching=False,
-    ) as vllm_model:
-        vllm_results = vllm_model.model.generate(
-            test_prompts, sampling_params=vllm_sampling_params)
+    vllm_results = vllm_model.model.generate(
+        test_prompts, sampling_params=vllm_sampling_params)
 
     for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
             vllm_results, hf_logprobs, hf_outputs,
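
The `vllm_sampling_params` list consumed here gives every prompt in the batch its own logprobs configuration, which is what lets a single `generate()` call cover the NONE/SAMPLE/PROMPT/SAMPLE_PROMPT compositions. A sketch of how such a mixed batch is assembled (the tuple list is illustrative; `SamplingParams` is the real vLLM class):

```python
from vllm import SamplingParams

# (num sample logprobs, num prompt logprobs) per request; None disables.
logprob_prompt_logprob_list = [(None, None), (5, None), (None, 3), (5, 3)]

vllm_sampling_params = [
    SamplingParams(max_tokens=5,
                   logprobs=num_lp,
                   prompt_logprobs=num_plp,
                   temperature=0.0)
    for num_lp, num_plp in logprob_prompt_logprob_list
]
```
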
@@ -260,21 +268,14 @@ def _test_case_get_logprobs_and_prompt_logprobs(
         assert vllm_result.prompt_logprobs is None
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype",
-                         ["half"])  # needed for comparing logprobs with HF
-# Include a very small max_num_batched_tokens to ensure we test chunking
-@pytest.mark.parametrize("max_num_batched_tokens", [16, 256])
+# @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("batch_logprobs_composition",
                          ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
 @pytest.mark.parametrize("temperature", [0.0, 2.0])
 def test_get_logprobs_and_prompt_logprobs(
-    hf_runner,
-    vllm_runner,
-    model: str,
-    dtype: str,
+    hf_model,
+    vllm_model,
     batch_logprobs_composition: str,
-    max_num_batched_tokens: int,
     temperature: float,
     example_prompts,
 ) -> None:
@@ -293,22 +294,16 @@ def test_get_logprobs_and_prompt_logprobs(
     requests in the batch under test.
 
     Args:
-      hf_runner
-      vllm_runner
-      model
-      dtype
+      hf_model
+      vllm_model
       batch_logprobs_composition: logprobs configuration for test batch
-      max_num_batched_tokens: token budget for scheduling
       example_prompts
       monkeypatch
     """
     _test_case_get_logprobs_and_prompt_logprobs(
-        hf_runner=hf_runner,
-        vllm_runner=vllm_runner,
-        model=model,
-        dtype=dtype,
+        hf_model=hf_model,
+        vllm_model=vllm_model,
         batch_logprobs_composition=batch_logprobs_composition,
-        max_num_batched_tokens=max_num_batched_tokens,
         temperature=temperature,
         example_prompts=example_prompts)
 
@@ -325,7 +320,8 @@ def test_max_logprobs(monkeypatch):
 
     runner = VllmRunner("facebook/opt-125m",
                         max_logprobs=1,
-                        enable_prefix_caching=False)
+                        enable_prefix_caching=False,
+                        max_model_len=256)
     vllm_sampling_params = SamplingParams(logprobs=1)
     # should pass
     runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
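
For context, `max_logprobs` is an engine-level ceiling, and a request exceeding it is rejected at validation time. A self-contained sketch assembled from this hunk and the context lines that follow it (the `ValueError` is my assumption about the exact exception type):

```python
import pytest

from vllm import SamplingParams
from ...conftest import VllmRunner  # same relative import as this file

runner = VllmRunner("facebook/opt-125m",
                    max_logprobs=1,
                    enable_prefix_caching=False,
                    max_model_len=256)

# At the cap: accepted.
runner.generate(["Hello world"],
                sampling_params=SamplingParams(logprobs=1))

# Above the cap: rejected before any generation happens
# (ValueError assumed here).
with pytest.raises(ValueError):
    runner.generate(["Hello world"],
                    sampling_params=SamplingParams(logprobs=2))
```
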
@@ -335,35 +331,23 @@ def test_max_logprobs(monkeypatch):
     runner.generate(["Hello world"], sampling_params=bad_sampling_params)
 
 
-@pytest.mark.parametrize("model", MODELS)
-def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch):
+def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
     """Engine should return `logprobs` and `prompt_logprobs` as `None`
 
     Args:
-        vllm_runner: vLLM engine runner fixture
-        model: model name
+        vllm_model: vLLM model fixture
         example_prompts: list of example prompts (test fixture)
         monkeypatch: supports editing env vars and rolling back changes
                      after the test
     """
-    override_backend_env_variable(monkeypatch, "FLASH_ATTN")
-
-    max_num_seqs = 256
-    max_num_batched_tokens = None
     max_tokens = 5
 
-    with vllm_runner(
-            model,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            enable_prefix_caching=False,
-    ) as vllm_model:
-        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
-                                                       logprobs=None,
-                                                       prompt_logprobs=None,
-                                                       temperature=0.0)
-        results_logprobs_none = vllm_model.model.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_none)
+    sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
+                                                   logprobs=None,
+                                                   prompt_logprobs=None,
+                                                   temperature=0.0)
+    results_logprobs_none = vllm_model.model.generate(
+        example_prompts, sampling_params=sampling_params_logprobs_none)
 
     for i in range(len(results_logprobs_none)):
         # Check sample logprobs are None
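
What the `logprobs=None` case asserts, roughly: with both knobs set to `None`, the returned `RequestOutput` carries no logprob payload at all. A sketch over the result list built above, using the public result attributes (the loop shape mirrors the test, not its exact code):

```python
for result in results_logprobs_none:
    assert result.outputs[0].logprobs is None  # no sample logprobs
    assert result.prompt_logprobs is None      # no prompt logprobs
```
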
@@ -373,35 +357,23 @@ def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch):
         assert results_logprobs_none[i].prompt_logprobs is None
 
 
-@pytest.mark.parametrize("model", MODELS)
-def test_zero_logprobs(vllm_runner, model, example_prompts, monkeypatch):
+def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
     """Engine should return sampled token and prompt token logprobs
 
     Args:
-        vllm_runner: vLLM engine runner fixture
-        model: model name
+        vllm_model: vLLM model fixture
         example_prompts: list of example prompts (test fixture)
         monkeypatch: supports editing env vars and rolling back changes
                      after the test
     """
-    override_backend_env_variable(monkeypatch, "FLASH_ATTN")
-
-    max_num_seqs = 256
-    max_num_batched_tokens = None
     max_tokens = 5
 
-    with vllm_runner(
-            model,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            enable_prefix_caching=False,
-    ) as vllm_model:
-        sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
-                                                       logprobs=0,
-                                                       prompt_logprobs=0,
-                                                       temperature=0.0)
-        results_logprobs_zero = vllm_model.model.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_zero)
+    sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
+                                                   logprobs=0,
+                                                   prompt_logprobs=0,
+                                                   temperature=0.0)
+    results_logprobs_zero = vllm_model.model.generate(
+        example_prompts, sampling_params=sampling_params_logprobs_zero)
 
     for i in range(len(results_logprobs_zero)):
         # Check that there is one sample logprob dict for each
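
The `logprobs=0` / `prompt_logprobs=0` case being verified here is the subtle one: vLLM treats 0 as "return the logprob of the chosen token only", not as "return nothing". Roughly what that implies for the results generated above (a sketch against the public `RequestOutput` shape, not the test's exact assertions):

```python
for result in results_logprobs_zero:
    sample = result.outputs[0]
    # One logprob dict per generated token...
    assert len(sample.logprobs) == len(sample.token_ids)
    for token_id, lp_dict in zip(sample.token_ids, sample.logprobs):
        # ...holding only the token that was actually sampled.
        assert list(lp_dict.keys()) == [token_id]
```
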