We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 42de6f8 + 394604d — commit 49cd46c (copy full SHA for 49cd46c)
src/flashdeberta/model.py
@@ -301,6 +301,12 @@ def __init__(self, config):
301
self.gradient_checkpointing = False
302
303
def get_attention_mask(self, attention_mask):
    """Normalize an attention mask to the 4D form used by the attention layers.

    A rank-2 (or lower) padding mask of shape (batch, seq) is expanded into a
    pairwise mask of shape (batch, 1, seq, seq), where position (q, k) is kept
    only if both token q and token k are unmasked. A rank-3 mask of shape
    (batch, seq, seq) just gains a broadcastable head dimension. Anything of
    higher rank is returned unchanged.
    """
    rank = attention_mask.dim()
    if rank <= 2:
        # (batch, seq) -> (batch, 1, 1, seq): per-key validity, broadcastable
        # over heads and query positions.
        widened = attention_mask.unsqueeze(1).unsqueeze(2)
        # Outer product of the mask with itself: (batch, 1, 1, seq) *
        # (batch, 1, seq, 1) broadcasts to (batch, 1, seq, seq).
        attention_mask = widened * widened.squeeze(-2).unsqueeze(-1)
    elif rank == 3:
        # (batch, seq, seq) -> (batch, 1, seq, seq): insert the head axis.
        attention_mask = attention_mask.unsqueeze(1)

    return attention_mask
311
312
class FlashDebertaV2PreTrainedModel(PreTrainedModel):
0 commit comments