@@ -165,10 +165,10 @@ def __init__(
165165 batch_size : int ,
166166 num_beams : int ,
167167 device : torch .device ,
168- length_penalty : Optional [ float ] = 1.0 ,
169- do_early_stopping : Optional [ Union [bool , str ] ] = False ,
170- num_beam_hyps_to_keep : Optional [ int ] = 1 ,
171- num_beam_groups : Optional [ int ] = 1 ,
168+ length_penalty : float = 1.0 ,
169+ do_early_stopping : Union [bool , str ] = False ,
170+ num_beam_hyps_to_keep : int = 1 ,
171+ num_beam_groups : int = 1 ,
172172 max_length : Optional [int ] = None ,
173173 ):
174174 logger .warning_once (
@@ -214,7 +214,7 @@ def __init__(
214214
215215 @property
216216 def is_done (self ) -> bool :
217- return self ._done .all ()
217+ return self ._done .all (). item ()
218218
219219 def process (
220220 self ,
@@ -225,8 +225,8 @@ def process(
225225 pad_token_id : Optional [Union [int , torch .Tensor ]] = None ,
226226 eos_token_id : Optional [Union [int , list [int ], torch .Tensor ]] = None ,
227227 beam_indices : Optional [torch .LongTensor ] = None ,
228- group_index : Optional [ int ] = 0 ,
229- decoder_prompt_len : Optional [ int ] = 0 ,
228+ group_index : int = 0 ,
229+ decoder_prompt_len : int = 0 ,
230230 ) -> dict [str , torch .Tensor ]:
231231 # add up to the length which the next_scores is calculated on (including decoder prompt)
232232 cur_len = input_ids .shape [- 1 ] + 1
@@ -460,9 +460,9 @@ def __init__(
460460 num_beams : int ,
461461 constraints : list [Constraint ],
462462 device : torch .device ,
463- length_penalty : Optional [ float ] = 1.0 ,
464- do_early_stopping : Optional [ Union [bool , str ] ] = False ,
465- num_beam_hyps_to_keep : Optional [ int ] = 1 ,
463+ length_penalty : float = 1.0 ,
464+ do_early_stopping : Union [bool , str ] = False ,
465+ num_beam_hyps_to_keep : int = 1 ,
466466 max_length : Optional [int ] = None ,
467467 ):
468468 logger .warning_once (
@@ -495,7 +495,7 @@ def __init__(
495495
496496 @property
497497 def is_done (self ) -> bool :
498- return self ._done .all ()
498+ return self ._done .all (). item ()
499499
500500 def make_constraint_states (self , n ):
501501 return [ConstraintListState ([constraint .copy () for constraint in self .constraints ]) for _ in range (n )]
@@ -515,7 +515,7 @@ def process(
515515 pad_token_id : Optional [Union [int , torch .Tensor ]] = None ,
516516 eos_token_id : Optional [Union [int , list [int ], torch .Tensor ]] = None ,
517517 beam_indices : Optional [torch .LongTensor ] = None ,
518- decoder_prompt_len : Optional [ int ] = 0 ,
518+ decoder_prompt_len : int = 0 ,
519519 ) -> tuple [torch .Tensor ]:
520520 r"""
521521 Args:
@@ -912,7 +912,9 @@ def finalize(
912912
913913
914914class BeamHypotheses :
915- def __init__ (self , num_beams : int , length_penalty : float , early_stopping : bool , max_length : Optional [int ] = None ):
915+ def __init__ (
916+ self , num_beams : int , length_penalty : float , early_stopping : Union [bool , str ], max_length : Optional [int ] = None
917+ ):
916918 """
917919 Initialize n-best list of hypotheses.
918920 """
0 commit comments