Skip to content

Commit 49cd46c

Browse files
authored
Merge pull request #6 from nitinsurya/main
GLiNER#263 Fix get_attention_mask function that allows for batched inputs
2 parents 42de6f8 + 394604d commit 49cd46c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/flashdeberta/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,12 @@ def __init__(self, config):
301301
self.gradient_checkpointing = False
302302

303303
def get_attention_mask(self, attention_mask):
    """Expand a padding mask to the shape expected by self-attention.

    A rank-1 or rank-2 mask (e.g. ``(batch, seq_len)``) is turned into a
    pairwise mask by multiplying the broadcast key mask with the broadcast
    query mask, yielding ``(batch, 1, seq_len, seq_len)`` for 2-D input.
    A rank-3 mask ``(batch, seq_len, seq_len)`` only gains a broadcastable
    head dimension. Any higher-rank mask is returned unchanged.
    """
    rank = attention_mask.dim()
    if rank <= 2:
        # Key-side mask with singleton head/query dims: (..., 1, 1, L).
        key_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Query-side mask: same values laid out along the row axis.
        query_mask = key_mask.squeeze(-2).unsqueeze(-1)
        attention_mask = key_mask * query_mask
    elif rank == 3:
        # Already pairwise; just add the head dimension for broadcasting.
        attention_mask = attention_mask.unsqueeze(1)
    return attention_mask
305311

306312
class FlashDebertaV2PreTrainedModel(PreTrainedModel):

0 commit comments

Comments
 (0)