@@ -637,3 +637,72 @@ def mock_sample(probs, *args, **kwargs):
637637 hf_probs = torch .softmax (hf_probs , dim = - 1 , dtype = torch .float )
638638 assert torch .allclose (hf_probs , sample_probs , atol = 1e-5 )
639639 assert torch .equal (hf_probs .eq (0 ), sample_probs .eq (0 ))
640+
641+
642+ @pytest .mark .parametrize ("device" , CUDA_DEVICES )
643+ def test_sampler_repetition_penalty_mixed (device : str ):
644+
645+ vocab_size = 8
646+
647+ def test_sampling_params (sampling_params : List [SamplingParams ]):
648+
649+ seq_group_metadata_list : List [SequenceGroupMetadata ] = []
650+ seq_lens : List [int ] = []
651+ for i in range (2 ):
652+ seq_group_metadata_list .append (
653+ SequenceGroupMetadata (
654+ request_id = f"test_{ i } " ,
655+ is_prompt = True ,
656+ seq_data = {0 : SequenceData ([1 , 2 , 3 ])},
657+ sampling_params = sampling_params [i ],
658+ block_tables = {0 : [1 ]},
659+ ))
660+ seq_lens .append (seq_group_metadata_list [- 1 ].seq_data [0 ].get_len ())
661+
662+ sampling_metadata = SamplingMetadata .prepare (
663+ seq_group_metadata_list ,
664+ seq_lens ,
665+ query_lens = seq_lens ,
666+ device = device ,
667+ pin_memory = is_pin_memory_available ())
668+
669+ fake_logits = torch .full ((2 , vocab_size ),
670+ 1e-2 ,
671+ device = device ,
672+ dtype = torch .float16 )
673+
674+ fake_logits [:, 5 ] = 1.1e-2
675+ fake_logits [:, 1 ] = 1.2e-2
676+
677+ sampler = MockLogitsSampler (fake_logits )
678+
679+ sampler_output = sampler (logits = fake_logits ,
680+ sampling_metadata = sampling_metadata )
681+
682+ generated_tokens = []
683+ for output in sampler_output :
684+ generated_tokens .append (output .samples [0 ].output_token )
685+
686+ return generated_tokens
687+
688+ # one configuration is greedy with repetition_penalty
689+ sampling_params_rep = SamplingParams (
690+ temperature = 0.0 ,
691+ repetition_penalty = 2.0 ,
692+ )
693+
694+ # other configuration is sampling w/o repetition_penalty
695+ sampling_params_sample = SamplingParams (
696+ temperature = 1.0 ,
697+ top_k = 1 ,
698+ seed = 42 ,
699+ )
700+
701+ tokens1 = test_sampling_params (
702+ [sampling_params_rep , sampling_params_sample ])
703+
704+ tokens2 = test_sampling_params (
705+ [sampling_params_sample , sampling_params_rep ])
706+
707+ assert tokens1 [0 ] == tokens2 [1 ]
708+ assert tokens1 [1 ] == tokens2 [0 ]
0 commit comments