Skip to content

Commit e5150f2

Browse files
authored
[Bugfix] Added test for sampling repetition penalty bug. (#5659)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 59a1eb5 commit e5150f2

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

tests/samplers/test_sampler.py

Lines changed: 69 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -631,3 +631,72 @@ def mock_sample(probs, *args, **kwargs):
631631
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
632632
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
633633
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
634+
635+
636+
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Regression test for the sampling repetition penalty bug.

    A batch of two requests is sampled twice: one request is greedy with a
    repetition penalty, the other is seeded sampling without one. The second
    run swaps the two requests' positions in the batch. Each configuration
    must produce the same token no matter which slot it occupied.
    """

    vocab_size = 8

    def _sample_batch(sampling_params: List[SamplingParams]):
        """Sample a two-request batch; return one token per request."""
        # Both requests share the same prompt tokens; only the sampling
        # parameters differ between the two batch slots.
        groups: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{idx}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params[idx],
                block_tables={0: [1]},
            ) for idx in range(2)
        ]
        seq_lens: List[int] = [
            group.seq_data[0].get_len() for group in groups
        ]

        sampling_metadata = SamplingMetadata.prepare(
            groups,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        # Uniform logits, except that token 1 (which is part of the prompt)
        # gets the largest value and token 5 the second largest.
        logits = torch.full((2, vocab_size),
                            1e-2,
                            device=device,
                            dtype=torch.float16)
        logits[:, 5] = 1.1e-2
        logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(logits)
        sampler_output = sampler(logits=logits,
                                 sampling_metadata=sampling_metadata)

        # Keep only the first sample's token for every request in the batch.
        return [result.samples[0].output_token for result in sampler_output]

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    rep_first = _sample_batch([sampling_params_rep, sampling_params_sample])
    sample_first = _sample_batch(
        [sampling_params_sample, sampling_params_rep])

    # Swapping the batch order must swap the tokens and nothing else.
    assert rep_first[0] == sample_first[1]
    assert rep_first[1] == sample_first[0]

0 commit comments

Comments
 (0)