Skip to content

Commit 057411c

Browse files
fix longformer slow down (#5811)
1 parent 89a78be commit 057411c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/transformers/modeling_longformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def forward(
311311
# is index masked or global attention
312312
is_index_masked = attention_mask < 0
313313
is_index_global_attn = attention_mask > 0
314-
is_global_attn = any(is_index_global_attn.flatten())
314+
is_global_attn = is_index_global_attn.flatten().any().item()
315315

316316
hidden_states = hidden_states.transpose(0, 1)
317317

0 commit comments

Comments
 (0)