@@ -22,7 +22,25 @@ def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
    return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

@bind(jax.jit, static_argnums=[4, 5])
-def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int = 8, dropout_rate: float = 0.0):
+def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int = 8, dropout_rate: float = 0.0) -> jax.Array:
+    """
+    Run cross-attention given a tuple of parameters and two sequences (x1 and x2).
+    The function takes a query sequence x1 and a key-value sequence x2 and returns an output of the same shape as x1.
+    B is the batch size.
+    T is the length of the query sequence, and S is the length of the key-value sequence.
+    Dq is the dimension of the query sequence, and Dkv is the dimension of the key-value sequence.
+    H is the number of attention heads.
+
+    Args:
+        params (tuple): tuple of parameters
+        x1 (jax.Array): query sequence. Shape: (B, T, Dq)
+        x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)
+        mask (jax.Array): mask tensor. Shape: (B, T, S)
+        n_heads (int, optional): number of attention heads. Defaults to 8.
+        dropout_rate (float, optional): dropout rate. Defaults to 0.0.
+
+    Returns:
+        jax.Array: output of cross-attention, with the same shape as x1 (B, T, Dq)
+    """
    B, T, Dq = x1.shape  # The original shape
    _, S, Dkv = x2.shape
    # Here the queries from x1 attend to the keys/values from x2
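For context, here is a minimal, self-contained sketch of the scaled dot-product cross-attention that the new docstring describes. The parameter layout (separate Wq, Wk, Wv, Wo projection matrices) and the mask convention (True marks positions to suppress, consistent with masked_fill) are assumptions for illustration and are not taken from this diff; only the shapes (B, T, Dq), (B, S, Dkv), and (B, T, S) come from the docstring above. Dropout is omitted for brevity.

import jax
import jax.numpy as jnp

def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
    # Where mask is True, replace entries of x with `value` (as in the diff above).
    return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

def cross_attention_sketch(Wq, Wk, Wv, Wo, x1, x2, mask, n_heads=8):
    # Hypothetical parameter shapes: Wq (Dq, D), Wk (Dkv, D), Wv (Dkv, D), Wo (D, Dq),
    # with D divisible by n_heads. The real `params` tuple in the repo may differ.
    B, T, Dq = x1.shape
    _, S, _ = x2.shape
    D = Wq.shape[-1]
    Dh = D // n_heads  # per-head dimension

    # Project, then split into heads: (B, len, D) -> (B, H, len, Dh).
    q = (x1 @ Wq).reshape(B, T, n_heads, Dh).transpose(0, 2, 1, 3)
    k = (x2 @ Wk).reshape(B, S, n_heads, Dh).transpose(0, 2, 1, 3)
    v = (x2 @ Wv).reshape(B, S, n_heads, Dh).transpose(0, 2, 1, 3)

    # Scaled dot-product scores: (B, H, T, S).
    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(Dh)

    # Suppress masked positions before the softmax; the (B, T, S) mask
    # broadcasts over the head axis.
    scores = masked_fill(scores, mask[:, None, :, :], -1e9)
    weights = jax.nn.softmax(scores, axis=-1)

    # Weighted sum of values, merge heads, project back to Dq: (B, T, Dq).
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
    return out @ Wo

Called with, say, x1 of shape (2, 5, 16), x2 of shape (2, 7, 24), and an all-False (2, 5, 7) mask, the sketch returns a (2, 5, 16) array, matching the docstring's statement that the output has the same shape as x1.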