@@ -158,7 +158,6 @@ def __init__(
         self.seed = seed

         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
-        self._return_attention_scores = False

         # Check for flash attention constraints
         if self._flash_attention and self._dropout > 0.0:
@@ -419,6 +418,7 @@ def _compute_attention(
         value,
         attention_mask=None,
         training=None,
+        return_attention_scores=False,
     ):
         """Applies Dot-product attention with query, key, value tensors.

@@ -442,7 +442,7 @@ def _compute_attention(
             attention_scores: Multi-headed attention weights.
         """
         # Check for flash attention constraints
-        if self._flash_attention and self._return_attention_scores:
+        if self._flash_attention and return_attention_scores:
             raise ValueError(
                 "Returning attention scores is not supported when flash "
                 "attention is enabled. Please disable flash attention to access"
@@ -452,7 +452,7 @@ def _compute_attention(
         # Determine whether to use dot-product attention
         use_dot_product_attention = not (
             self._dropout > 0.0
-            or self._return_attention_scores
+            or return_attention_scores
             or (len(query.shape) != 4)
         )

@@ -525,7 +525,6 @@ def call(
         training=None,
         use_causal_mask=False,
     ):
-        self._return_attention_scores = return_attention_scores
         if key is None:
             key = value

@@ -562,6 +561,7 @@ def call(
             value,
             attention_mask,
             training,
+            return_attention_scores,
         )
         attention_output = self._output_dense(attention_output)

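Taken together, these hunks thread `return_attention_scores` through `_compute_attention` as an explicit argument instead of stashing it on the layer inside `call()`, so the flag is scoped to each call rather than shared instance state. Below is a minimal usage sketch, assuming this diff targets `keras.layers.MultiHeadAttention`; the shapes and the `num_heads`/`key_dim` values are illustrative only, and the public call signature is unchanged by the refactor.

```python
import numpy as np
import keras

# Illustrative inputs: (batch, sequence_length, feature_dim).
query = np.random.normal(size=(2, 8, 16)).astype("float32")
value = np.random.normal(size=(2, 8, 16)).astype("float32")

mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)

# The flag is now forwarded per call instead of being written to
# self._return_attention_scores, so two calls with different values
# no longer depend on shared layer state.
output, scores = mha(query, value, return_attention_scores=True)
print(output.shape)  # (2, 8, 16)
print(scores.shape)  # (2, 4, 8, 8): (batch, num_heads, query_len, key_len)
```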