
Commit 16b821c

Avoid T5GemmaModelTest::test_eager_matches_sdpa_inference being flaky (#40702)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 519c252 commit 16b821c


1 file changed: +9 −0 lines changed


tests/models/t5gemma/test_modeling_t5gemma.py

Lines changed: 9 additions & 0 deletions
@@ -202,6 +202,15 @@ def prepare_config_and_inputs(self):
         input_ids = torch.where(input_ids == self.bos_token_id, 42, input_ids)
         decoder_input_ids = torch.where(decoder_input_ids == self.bos_token_id, 42, decoder_input_ids)
 
+        # Avoid leading PAD tokens from inputs.
+        # `T5GemmaForTokenClassification` and `T5GemmaForSequenceClassification` specify `use_cache=False` when
+        # calling `self.model`. For `self.use_attention_mask=False` case below, the model goes through
+        # `make_default_2d_attention_mask`. When there are some pad tokens at the beginning of a sequence, it can't
+        # attend to any place, and the computed mask `[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38]`
+        # causes larger differences in some equivalence tests.
+        # Let's avoid such leading PAD tokens.
+        decoder_input_ids[:, 0] = self.pad_token_id + 1
+
         attention_mask = None
         decoder_attention_mask = None
         if self.use_attention_mask:
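
For context, here is a minimal sketch (not part of the commit; the shapes and values are illustrative assumptions) of why a fully-masked attention row is numerically fragile. When every key position of a query is masked with the float32 minimum (`-3.4028e+38`, the value quoted in the diff comment above), softmax degenerates to a uniform distribution instead of "attend to nothing", so the output at that position is an arbitrary average over all value vectors, and eager and SDPA kernels can disagree there more than at regular positions.

import torch
import torch.nn.functional as F

mask_value = torch.finfo(torch.float32).min  # -3.4028e+38, as quoted in the diff comment

# Softmax over a row where every position carries the mask value: all scores become
# numerically identical, so the result is uniform (1/7 each) rather than zero.
scores = torch.randn(1, 7) + mask_value
print(torch.softmax(scores, dim=-1))  # tensor([[0.1429, 0.1429, ..., 0.1429]])

# Compare eager attention with F.scaled_dot_product_attention on a single query
# whose mask row is fully masked (hypothetical shapes, not taken from the test).
q = torch.randn(1, 1, 1, 8)
k = torch.randn(1, 1, 7, 8)
v = torch.randn(1, 1, 7, 8)
attn_mask = torch.full((1, 1, 1, 7), mask_value)

eager_weights = torch.softmax(q @ k.transpose(-1, -2) / 8**0.5 + attn_mask, dim=-1)
eager_out = eager_weights @ v
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Both outputs are averages over all value vectors; at such degenerate positions,
# implementation-level numeric differences can be amplified beyond the tolerances
# of the eager-vs-SDPA equivalence test, which is the flakiness worked around here.
print((eager_out - sdpa_out).abs().max())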
