@@ -62,6 +62,7 @@ def _construct_expected_sampling_metadata(
62
62
repetition_penalties = [1.0 for _ in range (num_reqs )]
63
63
top_k = [0 for _ in range (num_reqs )]
64
64
top_p = [0.0 for _ in range (num_reqs )]
65
+ min_p = [0.0 for _ in range (num_reqs )]
65
66
temperature = [0.0 for _ in range (num_reqs )]
66
67
stop_token_ids : List [Set [int ]] = [set () for _ in range (num_reqs )]
67
68
min_tokens = [0 for _ in range (num_reqs )]
@@ -80,12 +81,12 @@ def _construct_expected_sampling_metadata(
80
81
req .sampling_params .repetition_penalty )
81
82
top_k [index_in_input_batch ] = req .sampling_params .top_k
82
83
top_p [index_in_input_batch ] = req .sampling_params .top_p
84
+ min_p [index_in_input_batch ] = req .sampling_params .min_p
83
85
temperature [index_in_input_batch ] = req .sampling_params .temperature
84
86
stop_token_ids [
85
87
index_in_input_batch ] = req .sampling_params .all_stop_token_ids
86
88
min_tokens [index_in_input_batch ] = req .sampling_params .min_tokens
87
89
logit_bias [index_in_input_batch ] = req .sampling_params .logit_bias
88
-
89
90
return SamplingMetadata (
90
91
temperature = torch .tensor (temperature , dtype = torch .float ,
91
92
device = device ),
@@ -95,6 +96,8 @@ def _construct_expected_sampling_metadata(
95
96
top_k = torch .tensor (top_k , dtype = torch .int , device = device ),
96
97
no_top_p = all (x == 1.0 for x in top_p ),
97
98
no_top_k = all (x == 0 for x in top_k ),
99
+ min_p = torch .tensor (min_p , dtype = torch .float , device = device ),
100
+ no_min_p = all (x == 0.0 for x in min_p ),
98
101
generators = {},
99
102
max_num_logprobs = 0 ,
100
103
prompt_token_ids = make_tensor_with_pad (
0 commit comments