Skip to content

Commit 9d7acbb

Browse files
committed
update documentation
1 parent 84237ff commit 9d7acbb

File tree

1 file changed

+19
-1
lines changed

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,25 @@ def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
2222
return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)
2323

2424
@bind(jax.jit, static_argnums=[4, 5])
25-
def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0):
25+
def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0) -> jax.Array:
26+
"""
27+
Run cross-attention function given a list of parameters and two sequences (x1 and x2).
28+
The function takes in a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1.
29+
T is the length of the query sequence, and S is the length of the key-value sequence.
30+
Dq is the dimension of the query sequence, and Dkv is the dimension of the key-value sequence.
31+
H is the number of attention heads.
32+
33+
Args:
34+
params (tuple): tuple of parameters
35+
x1 (jax.Array): query sequence. Shape: (B, T, Dq)
36+
x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)
37+
mask (jax.Array): mask tensor. Shape: (B, T, S)
38+
n_heads (int, optional): number of attention heads. Defaults to 8.
39+
dropout_rate (float, optional): dropout rate. Defaults to 0.0.
40+
41+
Returns:
42+
jax.Array: output of cross-attention
43+
"""
2644
B, T, Dq = x1.shape # The original shape
2745
_, S, Dkv = x2.shape
2846
# in here we attend x2 to x1

0 commit comments

Comments (0)