|
65 | 65 | GenerateBeamEncoderDecoderOutput, |
66 | 66 | GenerateDecoderOnlyOutput, |
67 | 67 | GenerateEncoderDecoderOutput, |
| 68 | + GenerationConfig, |
68 | 69 | GreedySearchDecoderOnlyOutput, |
69 | 70 | GreedySearchEncoderDecoderOutput, |
70 | 71 | LogitsProcessorList, |
@@ -2478,6 +2479,35 @@ def test_batched_decoder_start_id(self): |
2478 | 2479 |
|
2479 | 2480 | self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist()) |
2480 | 2481 |
|
| 2482 | + def test_decoder_start_id_from_config(self): |
| 2483 | + # Refer to: (#30899) |
| 2484 | + articles = [ |
| 2485 | + "Justin Timberlake and Jessica Biel, welcome to parenthood.", |
| 2486 | + "Michael Phelps is arguably the most decorated Olympian of all time.", |
| 2487 | + ] |
| 2488 | + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") |
| 2489 | + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( |
| 2490 | + torch_device |
| 2491 | + ) |
| 2492 | + input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) |
| 2493 | + decoder_start_token_id = bart_model.generation_config.decoder_start_token_id |
| 2494 | + |
| 2495 | + # we should be able to take `decoder_start_token_id` from the model's generation config if the user passes a `GenerationConfig`
| 2496 | + outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) |
| 2497 | + |
| 2498 | + # If the generation config has neither `decoder_start_token_id` nor `bos_token_id`, an error is raised unless the user passes it in the config
| 2499 | + bart_model.generation_config.decoder_start_token_id = None |
| 2500 | + bart_model.generation_config.bos_token_id = None |
| 2501 | + outputs_with_user_id = bart_model.generate( |
| 2502 | + input_ids, |
| 2503 | + generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id), |
| 2504 | + ) |
| 2505 | + |
| 2506 | + self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist()) |
| 2507 | + |
| 2508 | + with self.assertRaises(ValueError): |
| 2509 | + outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) |
| 2510 | + |
2481 | 2511 | def test_contrastive_search_batched(self): |
2482 | 2512 | # PT-only test: TF doesn't have constrained beam search |
2483 | 2513 | # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs) |
|
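For context, here is a minimal sketch (not part of the diff) of the behaviour the new test exercises: how `decoder_start_token_id` is resolved when a `GenerationConfig` is passed to `generate()`. The tiny checkpoint name is taken from the test above; everything else is an illustrative assumption rather than the test's exact code.

```python
from transformers import AutoTokenizer, BartForConditionalGeneration, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
input_ids = tokenizer(["Hello world."], return_tensors="pt").input_ids

# With a bare GenerationConfig, `decoder_start_token_id` falls back to the model's own generation config.
out_default = model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))

# If the model's generation config defines neither `decoder_start_token_id` nor `bos_token_id`,
# the user has to supply it in the passed GenerationConfig; otherwise generate() raises a ValueError.
start_id = model.generation_config.decoder_start_token_id
model.generation_config.decoder_start_token_id = None
model.generation_config.bos_token_id = None
out_explicit = model.generate(
    input_ids,
    generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=start_id),
)
```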