Skip to content

Commit

Permalink
[Generate Test] fix greedy generate test (#8293)
Browse files Browse the repository at this point in the history
* fix greedy generate test

* delet ipdb
  • Loading branch information
patrickvonplaten authored Nov 4, 2020
1 parent 734afa3 commit cb966e6
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/test_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,6 @@ def test_greedy_generate(self):
# check `generate()` and `greedy_search()` are equal
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
max_length = 4

output_ids_generate = model.generate(
Expand All @@ -154,6 +150,13 @@ def test_greedy_generate(self):
max_length=max_length,
**logits_process_kwargs,
)

if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs

with torch.no_grad():
output_ids_greedy = model.greedy_search(
input_ids,
Expand Down

0 comments on commit cb966e6

Please sign in to comment.