Commit 089f8ae: support 2D attention mask (PaddlePaddle#1226)
1 parent: 0a338cc

File tree: 1 file changed (+4, −0 lines)

paddlenlp/transformers/bert/modeling.py (4 additions, 0 deletions)

@@ -495,6 +495,10 @@ def forward(self,
                 (input_ids == self.pad_token_id
                  ).astype(self.pooler.dense.weight.dtype) * -1e9,
                 axis=[1, 2])
+        else:
+            if attention_mask.ndim == 2:
+                # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
+                attention_mask = attention_mask.unsqueeze(axis=[1, 2])
 
         embedding_output = self.embeddings(
             input_ids=input_ids,
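The added `else` branch handles a user-supplied 2D padding mask of shape `[batch_size, sequence_length]`: attention scores are 4D (`[batch_size, num_heads, seq_len, seq_len]`), so the 2D mask must gain two singleton axes to broadcast correctly. A minimal NumPy sketch of that shape expansion (`np.expand_dims` standing in for Paddle's `unsqueeze(axis=[1, 2])`; the function name and sample data here are illustrative, not part of the commit):

```python
import numpy as np

def expand_attention_mask(mask_2d):
    # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len], mirroring the
    # unsqueeze(axis=[1, 2]) call added in this commit; the result then
    # broadcasts against [batch_size, num_heads, seq_len, seq_len] scores.
    return np.expand_dims(mask_2d, axis=(1, 2))

batch_size, seq_len = 2, 5
mask = np.ones((batch_size, seq_len), dtype=np.float32)
mask[:, 3:] = 0.0  # mark positions 3 and 4 as padding

expanded = expand_attention_mask(mask)
print(expanded.shape)  # (2, 1, 1, 5)
```

Note that the `attention_mask is None` branch above builds an *additive* mask directly (padding positions get `-1e9`, which drives their softmax weight toward zero), whereas a user-provided 2D mask is only reshaped here.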

0 commit comments