@@ -234,15 +234,22 @@ def to_sampling_params(self) -> SamplingParams:
234
234
235
235
logits_processors = None
236
236
if self .logit_bias :
237
+ logit_bias : Dict [int , float ] = {}
238
+ try :
239
+ for token_id , bias in self .logit_bias .items ():
240
+ # Convert token_id to integer before we add to LLMEngine
241
+ # Clamp the bias between -100 and 100 per OpenAI API spec
242
+ logit_bias [int (token_id )] = min (100 , max (- 100 , bias ))
243
+ except ValueError as exc :
244
+ raise ValueError (f"Found token_id `{ token_id } ` in logit_bias "
245
+ f"but token_id must be an integer or string "
246
+ f"representing an integer" ) from exc
237
247
238
248
def logit_bias_logits_processor (
239
249
token_ids : List [int ],
240
250
logits : torch .Tensor ) -> torch .Tensor :
241
- assert self .logit_bias is not None
242
- for token_id , bias in self .logit_bias .items ():
243
- # Clamp the bias between -100 and 100 per OpenAI API spec
244
- bias = min (100 , max (- 100 , bias ))
245
- logits [int (token_id )] += bias
251
+ for token_id , bias in logit_bias .items ():
252
+ logits [token_id ] += bias
246
253
return logits
247
254
248
255
logits_processors = [logit_bias_logits_processor ]
@@ -419,15 +426,22 @@ def to_sampling_params(self):
419
426
420
427
logits_processors = None
421
428
if self .logit_bias :
429
+ logit_bias : Dict [int , float ] = {}
430
+ try :
431
+ for token_id , bias in self .logit_bias .items ():
432
+ # Convert token_id to integer
433
+ # Clamp the bias between -100 and 100 per OpenAI API spec
434
+ logit_bias [int (token_id )] = min (100 , max (- 100 , bias ))
435
+ except ValueError as exc :
436
+ raise ValueError (f"Found token_id `{ token_id } ` in logit_bias "
437
+ f"but token_id must be an integer or string "
438
+ f"representing an integer" ) from exc
422
439
423
440
def logit_bias_logits_processor (
424
441
token_ids : List [int ],
425
442
logits : torch .Tensor ) -> torch .Tensor :
426
- assert self .logit_bias is not None
427
- for token_id , bias in self .logit_bias .items ():
428
- # Clamp the bias between -100 and 100 per OpenAI API spec
429
- bias = min (100 , max (- 100 , bias ))
430
- logits [int (token_id )] += bias
443
+ for token_id , bias in logit_bias .items ():
444
+ logits [token_id ] += bias
431
445
return logits
432
446
433
447
logits_processors = [logit_bias_logits_processor ]
0 commit comments