Commit b61b5a6

Add fp32 state test
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 5b005ec

tests/models/language/generation/test_hybrid.py

Lines changed: 65 additions & 9 deletions
@@ -31,7 +31,6 @@
     "hmellor/tiny-random-BambaForCausalLM",
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
-    "nvidia/Nemotron-H-8B-Base-8K",
 ]
 
 HF_UNSUPPORTED_MODELS = [
@@ -42,8 +41,7 @@
     "yujiepan/mamba2-codestral-v0.1-tiny-random",
     # transformers 4.55 is still producing garbage for this model
     # TODO(tdoublep): follow-up on transformers side
-    "ibm-granite/granite-4.0-tiny-preview",
-    "nvidia/Nemotron-H-8B-Base-8K",
+    "ibm-granite/granite-4.0-tiny-preview"
 ]
 
 V1_SUPPORTED_MODELS = [
@@ -54,7 +52,6 @@
     "hmellor/tiny-random-BambaForCausalLM",
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
-    "nvidia/Nemotron-H-8B-Base-8K",
 ]
 
 # Avoid OOM
@@ -92,9 +89,7 @@ def test_models(
         else:
             hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="auto") as vllm_model:
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
         vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
@@ -106,8 +101,7 @@ def test_models(
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False,
-                             mamba_ssm_cache_dtype="auto") as vllm_model:
+                             enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
     else:
@@ -424,3 +418,65 @@ def test_full_cuda_graph(
         name_0="hf" if hf_outputs is not None else "vllm-v0",
         name_1="vllm-v1",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_fp32_state(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         mamba_ssm_cache_dtype="float32",
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
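
Note on what the new test exercises: mamba_ssm_cache_dtype="float32" asks vLLM to keep the Mamba/SSM recurrent state cache in fp32, and the test checks that greedy logprobs under this setting stay close to the HF reference (or to the vLLM V0 run when HF is unsupported). Below is a minimal sketch, not part of the commit, of driving the same option through the offline vllm.LLM API; it assumes vllm_runner forwards its keyword arguments to LLM as elsewhere in the vLLM test suite, and the max_num_seqs=4 literal merely stands in for the file's MAX_NUM_SEQS constant.

from vllm import LLM, SamplingParams

# Sketch only: same model and cache option as test_fp32_state above.
llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",
    max_num_seqs=4,  # assumption: stands in for the test's MAX_NUM_SEQS
    mamba_ssm_cache_dtype="float32",  # keep the SSM state cache in fp32
    enable_prefix_caching=False,  # mirrors the V1 path in the test
)

# Greedy decoding with logprobs, mirroring generate_greedy_logprobs.
params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=5)
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)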
