@@ -631,3 +631,72 @@ def mock_sample(probs, *args, **kwargs):
631631 hf_probs = torch .softmax (hf_probs , dim = - 1 , dtype = torch .float )
632632 assert torch .allclose (hf_probs , sample_probs , atol = 1e-5 )
633633 assert torch .equal (hf_probs .eq (0 ), sample_probs .eq (0 ))
634+
635+
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Check that per-request tokens do not depend on batch position.

    A two-request batch mixes a greedy request that uses a repetition
    penalty with a seeded sampling request that does not. The batch is run
    in both orders, and each configuration must yield the same token in
    whichever slot it occupies.
    """
    vocab_size = 8

    # Greedy request (temperature 0) with a repetition penalty.
    greedy_with_penalty = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # Seeded sampling request with no repetition penalty.
    seeded_sampling = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    def run_batch(sampling_params: List[SamplingParams]):
        """Run the sampler on a two-prompt batch; return one token each."""
        # Both requests share the same prompt tokens and differ only in
        # their sampling parameters.
        groups: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{ i } ",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params[i],
                block_tables={0: [1]},
            ) for i in range(2)
        ]
        prompt_lens: List[int] = [
            group.seq_data[0].get_len() for group in groups
        ]

        metadata = SamplingMetadata.prepare(
            groups,
            prompt_lens,
            query_lens=prompt_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        # Near-uniform logits, identical across both rows, with columns 5
        # and 1 raised slightly above the baseline.
        logits = torch.full((2, vocab_size),
                            1e-2,
                            device=device,
                            dtype=torch.float16)
        for column, value in ((5, 1.1e-2), (1, 1.2e-2)):
            logits[:, column] = value

        mock_sampler = MockLogitsSampler(logits)
        outputs = mock_sampler(logits=logits, sampling_metadata=metadata)

        # First sample of every sequence group, in batch order.
        return [entry.samples[0].output_token for entry in outputs]

    penalty_first = run_batch([greedy_with_penalty, seeded_sampling])
    sampling_first = run_batch([seeded_sampling, greedy_with_penalty])

    # Swapping the batch order must swap the results and nothing else.
    assert penalty_first[0] == sampling_first[1]
    assert penalty_first[1] == sampling_first[0]
0 commit comments