Commit 9d05459

zucchini-nlp authored and ArthurZucker committed
Generation: get special tokens from model config (#30899)
* fix
* let's do this way?
* codestyle
* update
* add tests
1 parent e5d174f commit 9d05459

2 files changed: 53 additions, 1 deletion

src/transformers/generation/utils.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -1354,6 +1354,23 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
             self._static_cache.reset()  # reset the cache for a new generation
         return self._static_cache
 
+    def _get_decoder_start_token_id(
+        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
+    ) -> int:
+        decoder_start_token_id = (
+            decoder_start_token_id
+            if decoder_start_token_id is not None
+            else self.generation_config.decoder_start_token_id
+        )
+        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+
+        if decoder_start_token_id is not None:
+            return decoder_start_token_id
+        elif bos_token_id is not None:
+            return bos_token_id
+        else:
+            return
+
     def _prepare_special_tokens(
         self,
         generation_config: GenerationConfig,
@@ -1378,11 +1395,16 @@ def _tensor_or_none(token, device=None):
                 return token
             return torch.tensor(token, device=device, dtype=torch.long)
 
+        # for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
+        if self.config.is_encoder_decoder:
+            generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
+                generation_config.decoder_start_token_id, generation_config.bos_token_id
+            )
+
         bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
         eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
         pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
         decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
-        decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
         if eos_token_id is not None and eos_token_id.ndim == 0:
```
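
The fallback order the new helper implements is easiest to see in isolation. Below is a minimal, self-contained sketch of the same precedence; the function name `resolve_decoder_start_token_id` and the bare-integer arguments are illustrative, not part of the diff. An explicit argument wins, then the model's `generation_config.decoder_start_token_id`, then `bos_token_id`, and `None` comes back only when every source is empty (the caller errors out later in that case).

```python
from typing import List, Optional, Union


def resolve_decoder_start_token_id(
    arg_decoder_start: Optional[Union[int, List[int]]],
    arg_bos: Optional[int],
    config_decoder_start: Optional[Union[int, List[int]]],
    config_bos: Optional[int],
) -> Optional[Union[int, List[int]]]:
    """Illustrative copy of the precedence in `_get_decoder_start_token_id`."""
    # Each explicit argument falls back to the model's generation config.
    decoder_start = arg_decoder_start if arg_decoder_start is not None else config_decoder_start
    bos = arg_bos if arg_bos is not None else config_bos
    # `decoder_start_token_id` wins over `bos_token_id`; None means "unset".
    if decoder_start is not None:
        return decoder_start
    if bos is not None:
        return bos
    return None


assert resolve_decoder_start_token_id(2, None, 0, 1) == 2        # explicit argument wins
assert resolve_decoder_start_token_id(None, None, 0, 1) == 0     # then the generation config
assert resolve_decoder_start_token_id(None, None, None, 1) == 1  # then bos_token_id
assert resolve_decoder_start_token_id(None, None, None, None) is None  # caller raises later
```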

tests/generation/test_utils.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -65,6 +65,7 @@
     GenerateBeamEncoderDecoderOutput,
     GenerateDecoderOnlyOutput,
     GenerateEncoderDecoderOutput,
+    GenerationConfig,
     GreedySearchDecoderOnlyOutput,
     GreedySearchEncoderDecoderOutput,
     LogitsProcessorList,
@@ -2478,6 +2479,35 @@ def test_batched_decoder_start_id(self):
 
         self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
 
+    def test_decoder_start_id_from_config(self):
+        # Refer to: (#30899)
+        articles = [
+            "Justin Timberlake and Jessica Biel, welcome to parenthood.",
+            "Michael Phelps is arguably the most decorated Olympian of all time.",
+        ]
+        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+            torch_device
+        )
+        input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+        decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
+
+        # we should be able to take `decoder_start_token_id` from the model's generation config if the user passes a `GenerationConfig` type
+        outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
+
+        # if the generation config has no `decoder_start_token_id` or `bos_token_id`, we raise an error unless the user passes one in the config
+        bart_model.generation_config.decoder_start_token_id = None
+        bart_model.generation_config.bos_token_id = None
+        outputs_with_user_id = bart_model.generate(
+            input_ids,
+            generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id),
+        )
+
+        self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist())
+
+        with self.assertRaises(ValueError):
+            outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
+
     def test_contrastive_search_batched(self):
         # PT-only test: TF doesn't have constrained beam search
         # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
```
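
For reference, this is roughly how a user hits the fixed path: passing a fresh `GenerationConfig` that carries no `decoder_start_token_id`, which previously lost the value stored on the model. A minimal sketch, using the same `hf-internal-testing/tiny-random-bart` checkpoint as the test and assuming CPU execution:

```python
from transformers import AutoTokenizer, BartForConditionalGeneration, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")

inputs = tokenizer(
    "Michael Phelps is arguably the most decorated Olympian of all time.",
    return_tensors="pt",
)

# The user-supplied GenerationConfig has no `decoder_start_token_id`; with this
# commit, generate() falls back to `model.generation_config.decoder_start_token_id`
# (and then `bos_token_id`) instead of leaving it unset.
outputs = model.generate(inputs.input_ids, generation_config=GenerationConfig(do_sample=False))
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```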
