3 changes: 2 additions & 1 deletion tests/spec_decode/e2e/conftest.py
@@ -191,7 +191,8 @@ def generator_inner():
                 and llm.llm_engine.log_stats):
             for sate_logger in llm.llm_engine.stat_loggers.values():
                 sate_logger.local_interval = 0
-        set_random_seed(seed)
+        if seed is not None:
+            set_random_seed(seed)
 
         yield llm
         del llm
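
The conftest.py change above guards the global seeding call, so a fixture built with seed=None (as the new test below is) leaves the process-wide RNGs untouched. For context only, here is a minimal sketch of what a helper like set_random_seed conventionally does; the name set_random_seed_sketch and its body are an assumption for illustration, not code from this PR:

import random

import numpy as np
import torch


def set_random_seed_sketch(seed: int) -> None:
    # Seeding every RNG with a concrete integer makes generation reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # torch.manual_seed(None) would raise a TypeError

With the guard in place, passing seed=None simply skips this step, so any determinism has to come from per-request seeds rather than global state.
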
22 changes: 17 additions & 5 deletions tests/spec_decode/e2e/test_seed.py
@@ -21,7 +21,8 @@
         "num_speculative_tokens": 3,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
 @pytest.mark.parametrize("batch_size", [1, 8, 32])
 @pytest.mark.parametrize("temperature", [0.1, 1.0])
 @pytest.mark.parametrize(
@@ -30,15 +31,26 @@
         # Use smaller output len for fast test.
         10,
     ])
-@pytest.mark.parametrize("seed", [1])
-def test_seeded_consistency(baseline_llm_generator, batch_size: int,
-                            temperature: float, output_len: int):
+@pytest.mark.parametrize("seed", [None])
+def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
+                            batch_size: int, temperature: float,
+                            output_len: int):
     """Verify outputs are consistent across multiple runs with same seed
     """
     run_equality_correctness_test(baseline_llm_generator,
-                                  baseline_llm_generator,
+                                  test_llm_generator,
                                   batch_size,
                                   max_output_len=output_len,
                                   temperature=temperature,
                                   seeded=True,
                                   force_output_len=True)
+
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)
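
The updated test builds the baseline and test engines with different engine-level seeds (1 and 5), so their outputs can only match when the shared run_equality_correctness_test helper attaches identical per-request seeds to both runs (seeded=True); the second call with seeded=False is expected to fail the equality check, hence pytest.raises(AssertionError). Below is a rough illustration of the per-request seed mechanism this exercises, using vLLM's public API; the model name, prompt, and parameter values are placeholders, not taken from the PR:

from vllm import LLM, SamplingParams

# Engine-level seeds differ, mirroring baseline_llm_kwargs={"seed": 1}
# vs test_llm_kwargs={"seed": 5} in the parametrization above.
llm_a = LLM(model="facebook/opt-125m", seed=1)
llm_b = LLM(model="facebook/opt-125m", seed=5)

# With the same per-request seed, sampling should agree across the two
# engines even at temperature 1.0; without it, high-temperature outputs
# would generally diverge.
params = SamplingParams(temperature=1.0, max_tokens=10, seed=1234)
out_a = llm_a.generate(["The future of AI is"], params)
out_b = llm_b.generate(["The future of AI is"], params)
assert out_a[0].outputs[0].text == out_b[0].outputs[0].text

(Running two engines in one process is only for illustration; the test suite drives them through the baseline/test generator fixtures instead.)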