Skip to content

Commit

Permalink
add attn_mask input for encoder-decoder (#1431)
Browse files Browse the repository at this point in the history
  • Loading branch information
smallv0221 authored Dec 10, 2021
1 parent 10ac335 commit e468e19
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,8 @@ def forward(self,
self.encoder = enable_faster_encoder(self.encoder, need_build=False)
if encoder_output is None:
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.encoder(input_ids)
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
self.encoder = disable_faster_encoder(self.encoder)
if seq_len is None:
assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
Expand Down Expand Up @@ -1206,7 +1207,8 @@ def forward(self,
self.encoder = enable_faster_encoder(self.encoder, need_build=False)
if encoder_output is None:
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.encoder(input_ids)
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
self.encoder = disable_faster_encoder(self.encoder)
batch_size = paddle.shape(encoder_output)[0]
if seq_len is None:
Expand Down

0 comments on commit e468e19

Please sign in to comment.