3131 "hmellor/tiny-random-BambaForCausalLM" ,
3232 "ibm-granite/granite-4.0-tiny-preview" ,
3333 "tiiuae/Falcon-H1-0.5B-Base" ,
34- "nvidia/Nemotron-H-8B-Base-8K" ,
3534]
3635
3736HF_UNSUPPORTED_MODELS = [
4241 "yujiepan/mamba2-codestral-v0.1-tiny-random" ,
4342 # transformers 4.55 is still producing garbage for this model
4443 # TODO(tdoublep): follow-up on transformers side
45- "ibm-granite/granite-4.0-tiny-preview" ,
46- "nvidia/Nemotron-H-8B-Base-8K" ,
44+ "ibm-granite/granite-4.0-tiny-preview"
4745]
4846
4947V1_SUPPORTED_MODELS = [
5452 "hmellor/tiny-random-BambaForCausalLM" ,
5553 "ibm-granite/granite-4.0-tiny-preview" ,
5654 "tiiuae/Falcon-H1-0.5B-Base" ,
57- "nvidia/Nemotron-H-8B-Base-8K" ,
5855]
5956
6057# Avoid OOM
@@ -92,9 +89,7 @@ def test_models(
         else:
             hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="auto") as vllm_model:
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
         vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
@@ -106,8 +101,7 @@ def test_models(
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False,
-                             mamba_ssm_cache_dtype="auto") as vllm_model:
+                             enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
     else:
@@ -424,3 +418,65 @@ def test_full_cuda_graph(
         name_0="hf" if hf_outputs is not None else "vllm-v0",
         name_1="vllm-v1",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_fp32_state(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
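+    # Run the model with the mamba SSM state cache held in float32 and check
+    # that the V0 and V1 outputs still match the HF (or V0) reference.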
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         mamba_ssm_cache_dtype="float32",
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )