diff --git a/tests/attention_test.py b/tests/attention_test.py
index 5aae3a6..5c915e9 100644
--- a/tests/attention_test.py
+++ b/tests/attention_test.py
@@ -119,3 +119,106 @@ def test_score_mod(self):
         )
 
         self.assertEqual(output.shape, (batch_size, num_heads, seq_len_q, feature_size))
+
+    def test_autograd(self):
+        """Check that `flax_attention` is differentiable w.r.t. the query.
+
+        Takes `jax.grad` of the summed attention output with respect to the
+        query, then checks that the gradient has the query's shape and
+        contains only finite values.
+        """
+        # Prepare inputs
+        batch_size = 4
+        num_heads = 8
+        seq_len_q = 64
+        seq_len_kv = 64
+        feature_size = 32
+
+        # Random tensors for query, key, and value
+        key = jax.random.normal(
+            jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
+        )
+        query = jax.random.normal(
+            jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
+        )
+        value = jax.random.normal(
+            jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
+        )
+
+        def fn(query, key, value):
+            # jax.grad only differentiates scalar-valued functions, hence .sum().
+            return flax_attention(
+                query,
+                key,
+                value,
+            ).sum()
+
+        # argnums=0 selects the gradient with respect to the query only.
+        grad_fn = jax.grad(fn, argnums=0)
+        grad = grad_fn(query, key, value)
+
+        self.assertEqual(grad.shape, (batch_size, num_heads, seq_len_q, feature_size))
+        # The shape check alone would still pass for a NaN/inf gradient.
+        self.assertTrue(bool(np.isfinite(np.asarray(grad)).all()))
+
+    def test_autograd_equivalence_with_torch(self):
+        """Compare query gradients of `flax_attention` and `flex_attention`.
+
+        The same random inputs go through both implementations (the
+        `flex_attention` side receives them via `jax2torch`); the gradient of
+        the summed output with respect to the query must match within an
+        absolute tolerance of 1.5e-3.
+        """
+        # Prepare inputs
+        batch_size = 4
+        num_heads = 8
+        seq_len_q = 64
+        seq_len_kv = 64
+        feature_size = 32
+
+        # Random tensors for query, key, and value
+        key = jax.random.normal(
+            jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
+        )
+        query = jax.random.normal(
+            jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
+        )
+        value = jax.random.normal(
+            jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
+        )
+
+        def fn(query, key, value):
+            # jax.grad only differentiates scalar-valued functions, hence .sum().
+            return flax_attention(
+                query,
+                key,
+                value,
+            ).sum()
+
+        # argnums=0 selects the gradient with respect to the query only.
+        grad_fn = jax.grad(fn, argnums=0)
+        grad_jax = grad_fn(query, key, value)
+
+        query_torch = jax2torch(query)
+        key_torch = jax2torch(key)
+        value_torch = jax2torch(value)
+
+        # Track gradients for the query only, mirroring argnums=0 on the JAX side.
+        query_torch.requires_grad = True
+
+        output_torch = flex_attention(
+            query_torch,
+            key_torch,
+            value_torch,
+        ).sum()
+
+        output_torch.backward()
+
+        grad_torch = query_torch.grad.cpu().numpy()
+
+        # assert_allclose is NumPy's recommended replacement for
+        # assert_almost_equal; rtol=0 with atol=1.5e-3 keeps the absolute
+        # tolerance that decimal=3 used to express.
+        np.testing.assert_allclose(
+            np.asarray(grad_jax), grad_torch, rtol=0, atol=1.5e-3
+        )