Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit e4d2b6e

Browse files
tdoublep and Robert Shaw
authored and committed
[Bugfix] Added test for sampling repetition penalty bug. (vllm-project#5659)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent a0d8ed2 commit e4d2b6e

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

tests/samplers/test_sampler.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,3 +637,72 @@ def mock_sample(probs, *args, **kwargs):
637637
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
638638
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
639639
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
640+
641+
642+
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Regression test for repetition-penalty leakage across a batch.

    Builds a two-sequence batch where one sequence decodes greedily with a
    repetition penalty and the other samples (top_k=1, fixed seed) without
    one, runs the sampler, then swaps the two configurations' positions in
    the batch and runs again.  Each configuration must produce the same
    output token regardless of its position in the batch.
    """

    vocab_size = 8

    def run_batch(params_per_seq: List[SamplingParams]):
        """Sample one token per sequence for a fixed 2-sequence prompt batch."""
        group_metadata: List[SequenceGroupMetadata] = []
        prompt_lens: List[int] = []
        for idx in range(2):
            group = SequenceGroupMetadata(
                request_id=f"test_{idx}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=params_per_seq[idx],
                block_tables={0: [1]},
            )
            group_metadata.append(group)
            prompt_lens.append(group.seq_data[0].get_len())

        metadata = SamplingMetadata.prepare(
            group_metadata,
            prompt_lens,
            query_lens=prompt_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        # Near-uniform logits with token 1 as the argmax and token 5 as the
        # runner-up.  Token 1 also appears in the prompt, so a repetition
        # penalty can flip the chosen token.
        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=metadata)

        return [out.samples[0].output_token for out in sampler_output]

    # Configuration A: greedy decoding with a repetition penalty.
    penalized_params = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # Configuration B: seeded top-k=1 sampling, no repetition penalty.
    plain_params = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    first_order = run_batch([penalized_params, plain_params])
    swapped_order = run_batch([plain_params, penalized_params])

    # The token each configuration yields must not depend on batch position.
    assert first_order[0] == swapped_order[1]
    assert first_order[1] == swapped_order[0]

0 commit comments

Comments
 (0)