
Commit 2f9d49b

Adding PrefixConstrainedLogitsProcessor (#8529)
* Adding PrefixConstrainedLogitsProcessor
* fixing RAG and style_doc
* fixing black (v20 instead of v19)
* Improving doc in generation_logits_process.py
* Improving docs and typing in generation_utils.py
* docs improvement
* adding test and fixing doc typo
* fixing doc_len
* isort on test
* fixed test
* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 3bc1540 commit 2f9d49b

File tree

4 files changed, +77 -3 lines changed

src/transformers/generation_logits_process.py
src/transformers/generation_utils.py
src/transformers/models/rag/modeling_rag.py
tests/test_generation_logits_process.py

src/transformers/generation_logits_process.py

Lines changed: 29 additions & 1 deletion
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from abc import ABC
-from typing import Iterable, List
+from typing import Callable, Iterable, List
 
 import numpy as np
 import torch
@@ -372,3 +373,30 @@ def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_toke
         )
         scores = scores.masked_fill(banned_mask, -float("inf"))
         return scores
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
+    constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
+    information.
+
+    Args:
+        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
+            This function constrains the beam search to allowed tokens only at each step. It takes 2 arguments:
+            :obj:`input_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed tokens for the
+            next generation step conditioned on the previously generated tokens :obj:`input_ids` and the batch ID
+            :obj:`batch_id`.
+    """
+
+    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
+        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+        self._num_beams = num_beams
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        mask = torch.full_like(scores, -math.inf)
+        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+            for beam_id, sent in enumerate(beam_sent):
+                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+
+        return scores + mask
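
Taken on its own, the new processor is straightforward to exercise outside of generate. The minimal sketch below mirrors the unit test added in this commit; the callback name allowed_tokens and the toy tensors are illustrative, not part of the diff.

import torch

from transformers.generation_logits_process import PrefixConstrainedLogitsProcessor

# Illustrative constraint: batch 0 may only emit tokens 0 or 1, batch 1 only tokens 2 or 3.
def allowed_tokens(batch_id, input_ids):
    return [[0, 1], [2, 3]][batch_id]

processor = PrefixConstrainedLogitsProcessor(allowed_tokens, num_beams=1)

input_ids = torch.tensor([[0, 1], [2, 3]])  # shape (batch_size * num_beams, seq_len) = (2, 2)
scores = torch.zeros(2, 5)                  # uniform logits over a 5-token vocabulary
masked = processor(input_ids, scores)       # disallowed entries are pushed to -inf
print(torch.isinf(masked))                  # True exactly where the callback forbade a token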

src/transformers/generation_utils.py

Lines changed: 16 additions & 1 deletion
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch.nn import functional as F
@@ -26,6 +26,7 @@
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
+    PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
@@ -258,6 +259,8 @@ def _get_logits_processor(
         bad_words_ids: List[List[int]],
         min_length: int,
         eos_token_id: int,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
+        num_beams: int,
     ) -> LogitsProcessorList:
         """
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -285,6 +288,8 @@ def _get_logits_processor(
             processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
         if min_length is not None and eos_token_id is not None and min_length > -1:
             processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
+        if prefix_allowed_tokens_fn is not None:
+            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
         return processors
 
     @torch.no_grad()
@@ -309,6 +314,7 @@ def generate(
         num_return_sequences: Optional[int] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **model_kwargs
     ) -> torch.LongTensor:
         r"""
@@ -375,6 +381,13 @@ def generate(
             use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
+            prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments :obj:`input_ids` and the batch ID
+                :obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
+                conditioned on the previously generated tokens :obj:`input_ids` and the batch ID :obj:`batch_id`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
                 model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -494,6 +507,8 @@ def generate(
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
         )
 
         if is_greedy_gen_mode:
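
For context on where the new argument plugs into generate, here is a hedged usage sketch. The checkpoint name facebook/bart-base, the forced prefix, and the position bookkeeping around the decoder start token are illustrative assumptions; the only thing taken from this diff is the prefix_allowed_tokens_fn keyword itself.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

# Force every beam to start with a fixed entity string, then release the constraint.
entity_ids = tokenizer(" London", add_special_tokens=False).input_ids

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # input_ids holds what this beam has generated so far; position 0 is the decoder start token.
    step = input_ids.shape[-1] - 1
    if step < len(entity_ids):
        return [entity_ids[step]]
    return list(range(len(tokenizer)))  # no constraint once the prefix has been emitted

inputs = tokenizer("The capital of England is", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    num_beams=2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))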

src/transformers/models/rag/modeling_rag.py

Lines changed: 11 additions & 1 deletion
@@ -15,7 +15,7 @@
 """RAG model implementation."""
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -1229,6 +1229,7 @@ def generate(
         num_return_sequences=None,
         decoder_start_token_id=None,
         n_docs=None,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
         **model_kwargs
     ):
         """
@@ -1302,6 +1303,13 @@ def generate(
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
             n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
+            prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments :obj:`input_ids` and the batch ID
+                :obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
+                conditioned on the previously generated tokens :obj:`input_ids` and the batch ID :obj:`batch_id`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
 
         Return:
             :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1395,6 +1403,8 @@ def extend_enc_output(tensor, num_beams=None):
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            num_beams=num_beams,
        )
 
        if num_beams == 1:
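
The same callback can also be forwarded through RagTokenForGeneration.generate, which this file now routes into _get_logits_processor. A compact sketch, assuming the standard facebook/rag-token-nq demo checkpoint and dummy retriever index from the library documentation; the constraint itself is again illustrative.

from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

first_token_ids = tokenizer.generator(" The", add_special_tokens=False).input_ids

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Illustrative: pin the first generated token, then allow the generator's full vocabulary.
    if input_ids.shape[-1] <= 1:
        return first_token_ids
    return list(range(model.config.generator.vocab_size))

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(
    input_ids=input_dict["input_ids"],
    num_beams=2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))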

tests/test_generation_logits_process.py

Lines changed: 21 additions & 0 deletions
@@ -31,6 +31,7 @@
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
+    PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
@@ -281,3 +282,23 @@ def test_processor_list(self):
 
         # input_ids should never be changed
         self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
+
+    def test_prefix_constrained_logits_processor(self):
+        vocab_size = 5
+        batch_size = 2
+
+        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+
+        def prefix_allowed_tokens_fn(batch_id, inputs_ids):
+            return [[0, 1], [2, 3]][batch_id]
+
+        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)
+
+        filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())
+
+        # batch 1: 1st, 2nd (0, 1) token are allowed
+        # batch 2: 3rd, 4th (2, 3) token are allowed
+        self.assertListEqual(
+            torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
+        )
