@@ -624,9 +624,7 @@ def __init__(self, sequence_bias: Dict[Tuple[int], float]):
 
         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is infered in the first usage, which inhibits initializing here)
-        self.sequences_length_greater_than_1 = []
         self.length_1_bias = None
-        self.length_greather_than_1_bias = None
         self.prepared_bias_variables = False
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -642,11 +640,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             bias += self.length_1_bias
 
         # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
-        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
-        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
-        # may become complete this iteration.
-        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
-        for sequence_ids in self.sequences_length_greater_than_1:
+        for sequence_ids, sequence_bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
+                continue
             if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                 continue
             prefix_length = len(sequence_ids) - 1
@@ -655,25 +651,20 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
                 input_ids[:, -prefix_length:],
                 torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
             ).prod(dim=1)
-            matching_mask[:, last_token] |= matching_rows.bool()
-        bias += torch.where(
-            matching_mask,
-            self.length_greather_than_1_bias,
-            torch.tensor(0.0, device=self.length_greather_than_1_bias.device),
-        )
+            bias[:, last_token] += torch.where(
+                matching_rows.bool(), sequence_bias, torch.tensor(0.0, device=input_ids.device)
+            )
 
         # 5 - apply the bias to the scores
         scores = scores + bias
         return scores
 
     def _prepare_bias_variables(self, scores: torch.FloatTensor):
         vocabulary_size = scores.shape[-1]
-        sequence_bias = self.sequence_bias
-        tokens_with_bias = []
 
         # Check biased tokens out of bounds
         invalid_biases = []
-        for sequence_ids in sequence_bias:
+        for sequence_ids in self.sequence_bias:
             for token_id in sequence_ids:
                 if token_id >= vocabulary_size:
                     invalid_biases.append(token_id)
@@ -686,20 +677,9 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        for sequence_ids, bias in sequence_bias.items():
+        for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
                 self.length_1_bias[sequence_ids[-1]] = bias
-            else:
-                self.sequences_length_greater_than_1.append(sequence_ids)
-                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
-                    raise ValueError(
-                        "Setting a bias on sequences that share a common token termination is not yet supported. "
-                        "Please open an issue if you see this error message (after checking that it doesn't already "
-                        "exist)."
-                    )
-                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
-            tokens_with_bias.append(sequence_ids[-1])
 
         self.prepared_bias_variables = True
 
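For reference, here is a minimal, self-contained sketch of the biasing logic after this change, run on toy tensors. The `sequence_bias` dict, vocabulary size, and inputs below are made up for illustration; the real processor in `logits_process.py` additionally validates its arguments and caches the length-1 bias across calls.

```python
import torch

# Toy inputs (made up for illustration): a bias dict in the same format as
# the processor's `sequence_bias`, a tiny vocabulary, and a batch of two prefixes.
sequence_bias = {(5,): 2.0, (3, 7): -4.0}
vocab_size = 10
input_ids = torch.tensor([[1, 2, 3], [1, 2, 4]])
scores = torch.zeros((2, vocab_size))

bias = torch.zeros_like(scores)

# Length-1 sequences: the bias is applied unconditionally to that token's column.
length_1_bias = torch.zeros((vocab_size,))
for sequence_ids, token_bias in sequence_bias.items():
    if len(sequence_ids) == 1:
        length_1_bias[sequence_ids[-1]] = token_bias
bias += length_1_bias

# Length > 1 sequences: bias the last token only in rows whose context ends with
# the sequence prefix, i.e. rows where the sequence could be completed this step
# (this mirrors the new per-sequence loop in the diff above).
for sequence_ids, token_bias in sequence_bias.items():
    if len(sequence_ids) == 1 or len(sequence_ids) > input_ids.shape[1]:
        continue
    prefix_length = len(sequence_ids) - 1
    last_token = sequence_ids[-1]
    matching_rows = torch.eq(
        input_ids[:, -prefix_length:],
        torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype),
    ).prod(dim=1)
    bias[:, last_token] += torch.where(
        matching_rows.bool(), token_bias, torch.tensor(0.0)
    )

print(bias)
# Column 5 gets +2.0 in both rows; column 7 gets -4.0 only in the first row,
# whose context ends with the prefix (3,).
```

Handling each biased sequence inside the loop removes the shared `length_greather_than_1_bias` buffer, which is what previously forced the restriction (and `ValueError`) on sequences sharing a termination token.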