Commit 18f79d6

Fix negative index handling in MultiHeadAttention attention_axes (#21721)
* Simplify save_img: remove _format, normalize jpg→jpeg, add RGBA→RGB handling and tests
* Fix negative index handling in MultiHeadAttention attention_axes
1 parent 18e0364 · commit 18f79d6

3 files changed: +35 −1 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@ __pycache__
 **/.vscode test/**
 **/.vscode-smoke/**
 **/.venv*/
+venv
 bin/**
 build/**
 obj/**
```

keras/src/layers/attention/multi_head_attention.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -378,7 +378,10 @@ def _build_attention(self, rank):
         if self._attention_axes is None:
             self._attention_axes = tuple(range(1, rank - 2))
         else:
-            self._attention_axes = tuple(self._attention_axes)
+            self._attention_axes = tuple(
+                axis if axis >= 0 else (rank - 1) + axis
+                for axis in self._attention_axes
+            )
         (
             self._dot_product_equation,
             self._combine_equation,
```
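Context for the `(rank - 1) + axis` offset (an inference, not stated in the diff itself): `_build_attention` appears to receive the rank of the *projected* query tensor, which is the input rank plus one because the projection inserts a `num_heads` axis; the default branch `tuple(range(1, rank - 2))` (every axis except batch, heads, and key_dim) is consistent with that. A negative axis is therefore resolved Python-style against the input rank, `rank - 1`. A minimal standalone sketch of the arithmetic, with `normalize_attention_axes` as a hypothetical helper name:

```python
def normalize_attention_axes(attention_axes, rank):
    # `rank` is assumed to be the rank of the projected query tensor
    # (input rank + 1, since the projection inserts a num_heads axis),
    # so a negative axis resolves Python-style against the input rank,
    # which is rank - 1.
    return tuple(
        axis if axis >= 0 else (rank - 1) + axis
        for axis in attention_axes
    )


# An input of shape (2, 3, 8, 4) has rank 4; the projected query adds a
# num_heads axis, giving rank 5. Axis -2 then resolves to (5 - 1) - 2 = 2,
# i.e. the same axis that attention_axes=2 selects in the test below.
assert normalize_attention_axes((-2,), rank=5) == (2,)
assert normalize_attention_axes((1, -2), rank=5) == (1, 2)
```

Note the offset is `rank - 1` rather than `rank` precisely because `rank` counts the extra heads dimension that is never part of the user-facing input shape.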

keras/src/layers/attention/multi_head_attention_test.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -203,6 +203,36 @@ def test_high_dim_attention(
             run_training_check=False,
         )
 
+    def test_attention_axes_negative_indexing(self):
+        x = np.random.normal(size=(2, 3, 8, 4))
+
+        # Create two layers with equivalent positive and negative indices
+        mha_pos = layers.MultiHeadAttention(
+            num_heads=2, key_dim=4, attention_axes=2
+        )
+        mha_neg = layers.MultiHeadAttention(
+            num_heads=2, key_dim=4, attention_axes=-2
+        )
+
+        # Initialize both layers
+        _ = mha_pos(x, x)
+        _ = mha_neg(x, x)
+
+        # Set same weights for fair comparison
+        mha_neg.set_weights(mha_pos.get_weights())
+
+        # Get outputs and attention scores
+        z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
+        z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)
+
+        # Verify shapes match
+        self.assertEqual(z_pos.shape, z_neg.shape)
+        self.assertEqual(a_pos.shape, a_neg.shape)
+
+        # Verify outputs are identical
+        self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
+        self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)
+
     @parameterized.named_parameters(
         ("without_key_same_proj", (4, 8), (2, 8), None, None),
         ("with_key_same_proj", (4, 8), (2, 8), (2, 3), None),
```
