Commit 89136ff

Generate: sequence bias can handle same terminations (#24822)
1 parent 37d8611 commit 89136ff
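
For context, the change in this commit lifts the previous restriction that no two biased sequences could end in the same token. A minimal sketch of the now-supported usage, applying the processor directly to dummy tensors (the token ids, vocabulary size, and bias values below are made up for illustration):

import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

# Two multi-token sequences that end in the same token (2). Before this commit,
# the first call to the processor raised a ValueError for such a configuration.
sequence_bias = {(3, 2): -10.0, (5, 4, 2): 8.0}
processor = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)

input_ids = torch.tensor([[0, 5, 4]])        # a batch with one partial sequence
scores = torch.zeros((1, 6))                 # dummy logits over a 6-token vocabulary
biased_scores = processor(input_ids, scores)
# Only the prefix (5, 4) matches the context, so token 2 receives +8.0; the
# bias from (3, 2) is not applied because the prefix (3,) does not match.
print(biased_scores)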

2 files changed: +11 -28 lines changed

src/transformers/generation/logits_process.py

Lines changed: 8 additions & 28 deletions
@@ -624,9 +624,7 @@ def __init__(self, sequence_bias: Dict[Tuple[int], float]):
 
         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is infered in the first usage, which inhibits initializing here)
-        self.sequences_length_greater_than_1 = []
         self.length_1_bias = None
-        self.length_greather_than_1_bias = None
         self.prepared_bias_variables = False
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -642,11 +640,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         bias += self.length_1_bias
 
         # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
-        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
-        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
-        # may become complete this iteration.
-        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
-        for sequence_ids in self.sequences_length_greater_than_1:
+        for sequence_ids, sequence_bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
+                continue
             if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                 continue
             prefix_length = len(sequence_ids) - 1
@@ -655,25 +651,20 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
                 input_ids[:, -prefix_length:],
                 torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
             ).prod(dim=1)
-            matching_mask[:, last_token] |= matching_rows.bool()
-        bias += torch.where(
-            matching_mask,
-            self.length_greather_than_1_bias,
-            torch.tensor(0.0, device=self.length_greather_than_1_bias.device),
-        )
+            bias[:, last_token] += torch.where(
+                matching_rows.bool(), sequence_bias, torch.tensor(0.0, device=input_ids.device)
+            )
 
         # 5 - apply the bias to the scores
         scores = scores + bias
         return scores
 
     def _prepare_bias_variables(self, scores: torch.FloatTensor):
         vocabulary_size = scores.shape[-1]
-        sequence_bias = self.sequence_bias
-        tokens_with_bias = []
 
         # Check biased tokens out of bounds
         invalid_biases = []
-        for sequence_ids in sequence_bias:
+        for sequence_ids in self.sequence_bias:
             for token_id in sequence_ids:
                 if token_id >= vocabulary_size:
                     invalid_biases.append(token_id)
@@ -686,20 +677,9 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        for sequence_ids, bias in sequence_bias.items():
+        for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
                 self.length_1_bias[sequence_ids[-1]] = bias
-            else:
-                self.sequences_length_greater_than_1.append(sequence_ids)
-                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
-                    raise ValueError(
-                        "Setting a bias on sequences that share a common token termination is not yet supported. "
-                        "Please open an issue if you see this error message (after checking that it doesn't already "
-                        "exist)."
-                    )
-                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
-                tokens_with_bias.append(sequence_ids[-1])
 
         self.prepared_bias_variables = True
 
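
To make the new control flow easier to follow outside of the diff, here is a rough standalone sketch of the per-sequence loop introduced above, using plain torch and made-up tensors; it mirrors the logic of the committed code rather than importing it:

import torch

# Toy setup: a batch of 2 partial sequences and a 6-token vocabulary.
input_ids = torch.tensor([[0, 1, 3], [0, 5, 4]])
scores = torch.zeros((2, 6))
# Two multi-token sequences sharing the termination token 2.
sequence_bias = {(3, 2): -10.0, (5, 4, 2): 8.0}

bias = torch.zeros_like(scores)
for sequence_ids, value in sequence_bias.items():
    if len(sequence_ids) == 1:  # length-1 sequences are handled separately in the processor
        continue
    if len(sequence_ids) > input_ids.shape[1]:  # longer than the context, ignore
        continue
    prefix_length = len(sequence_ids) - 1
    last_token = sequence_ids[-1]
    # A row matches when its last `prefix_length` tokens equal the sequence prefix.
    matching_rows = torch.eq(
        input_ids[:, -prefix_length:],
        torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype),
    ).prod(dim=1)
    # Accumulating per sequence lets overlapping terminations add up row by row,
    # instead of colliding in a single shared `length_greather_than_1_bias` tensor.
    bias[:, last_token] += torch.where(matching_rows.bool(), value, torch.tensor(0.0))

print(scores + bias)  # row 0 gets -10.0 on token 2, row 1 gets +8.0 on token 2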

tests/generation/test_logits_process.py

Lines changed: 3 additions & 0 deletions
@@ -520,6 +520,9 @@ def test_bias_dist_processor(self):
         input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
         positive_bias = {(1,): 100.0, (4,): 100.0}
         negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
+        # biases the same termination twice, to ensure we can handle overlapping terminations (it won't have an effect
+        # on the test cases, though)
+        negative_bias.update({(1, 3, 1, 3, 1, 3): -100.0})
         sequence_bias = {**positive_bias, **negative_bias}
 
         # scores = 0 to facilitate checks
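
The added key (1, 3, 1, 3, 1, 3) ends in token 3, just like (1, 3, 1, 3), which is exactly the configuration that used to raise a ValueError on the first call. It has no effect on the expected scores because it is longer than the 4-token context and is skipped by the length check. A small sketch reproducing the test setup on CPU (the test itself runs on `torch_device`, and the vocabulary size of 5 is an assumption for illustration):

import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=torch.long)
sequence_bias = {
    (1,): 100.0,
    (4,): 100.0,
    (1, 0): -100.0,
    (0, 1, 2): -100.0,
    (1, 3, 1, 3): -100.0,
    (1, 3, 1, 3, 1, 3): -100.0,  # same termination token as (1, 3, 1, 3)
}
processor = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)

scores = torch.zeros((2, 5))  # scores = 0 to facilitate checks, as in the test
out = processor(input_ids, scores)
# The 6-token sequence is longer than the 4-token context, so it is skipped and
# the remaining biases behave exactly as before.
print(out)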

0 commit comments
