Skip to content

Commit 56be5e8

Browse files
Fix: Raise informative exception when prefix_allowed_tokens_fn return empty set of tokens (#27797)
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 307a7d0 commit 56be5e8

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/transformers/generation/logits_process.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
12291229
mask = torch.full_like(scores, -math.inf)
12301230
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
12311231
for beam_id, sent in enumerate(beam_sent):
1232-
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
1232+
prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
1233+
if len(prefix_allowed_tokens) == 0:
1234+
raise ValueError(
1235+
f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
1236+
f"This means that the constraint is unsatisfiable. Please check your implementation"
1237+
f"of `prefix_allowed_tokens_fn` "
1238+
)
1239+
mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
12331240

12341241
return scores + mask
12351242

tests/generation/test_logits_process.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,13 @@ def prefix_allowed_tokens_fn(batch_id, inputs_ids):
610610
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
611611
)
612612

613+
def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids):
614+
return []
615+
616+
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1)
617+
618+
self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone())
619+
613620
def test_hamming_diversity(self):
614621
vocab_size = 4
615622
num_beams = 2

0 commit comments

Comments
 (0)