@@ -36,21 +36,23 @@ def top_p_logit_processor(logits, top_p: float):
3636 cumulative_probs = torch .cumsum (F .softmax (sorted_logits , dim = - 1 ), dim = - 1 )
3737
3838 sorted_indices_to_remove = cumulative_probs > top_p
39- sorted_indices_to_remove [..., 1 :] = sorted_indices_to_remove [..., :- 1 ].clone ()
39+
40+ sorted_indices_to_remove = torch .roll (sorted_indices_to_remove , 1 , - 1 )
4041 sorted_indices_to_remove [..., 0 ] = 0
4142
4243 indices_to_remove = sorted_indices_to_remove .scatter (dim = 1 , index = sorted_indices , src = sorted_indices_to_remove )
4344 logits [indices_to_remove ] = - float ("inf" )
4445 return logits
4546
46- def logit_processor (processor :str , logits , attrs ):
47+
48+ def logit_processor (processor : str , logits , attrs ):
4749 """
4850 do logit process for given logits.
4951
5052 Args:
51- processor(str): the type of logit processor
53+ processor(str): the type of logit processor
5254 logits(torch.Tensor): input logits
53- attrs(dict): attrs of the logit processor
55+ attrs(dict): attrs of the logit processor
5456
5557 Returns:
5658 logits after process
@@ -61,6 +63,6 @@ def logit_processor(processor:str, logits , attrs):
6163 func = _LOGIT_PROCESSOR_MAP [processor ]
6264 try :
6365 logits = func (logits , attrs )
64- except Exception as e :
66+ except Exception :
6567 return logits
66- return logits
68+ return logits
0 commit comments