Skip to content

Commit 41b7cfe

Browse files
committed
feat(reduceformer): proj drop
1 parent 2bd100a commit 41b7cfe

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

src/equimo/layers/attention.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __call__(
9797
attn = jnp.where(mask == 0, jnp.finfo(attn.dtype).min, attn)
9898

9999
attn = jax.nn.softmax(attn, axis=-1)
100-
attn = self.attn_drop(attn, inference=inference, key=key1)
100+
attn = self.attn_drop(attn, inference=inference, key=key1)
101101

102102
x = jnp.einsum("hqk,hkd->hqd", attn, v)
103103
x = rearrange(x, "h s d -> s (h d)")
@@ -1591,7 +1591,6 @@ class RFAttention(eqx.Module):
15911591
qkv: eqx.nn.Conv2d
15921592
aggreg: list[eqx.nn.Conv2d]
15931593
proj: SingleConvBlock
1594-
attn_drop: eqx.nn.Dropout
15951594
proj_drop: eqx.nn.Dropout
15961595

15971596
def __init__(
@@ -1606,7 +1605,6 @@ def __init__(
16061605
scales: Sequence[int] = (5,),
16071606
use_bias: bool = False,
16081607
kernel_func: Callable = jax.nn.relu,
1609-
attn_drop: float = 0.0,
16101608
proj_drop: float = 0.0,
16111609
# TODO: Benchmark against LN, RMSN, NsLN
16121610
norm_layer: eqx.Module = eqx.nn.GroupNorm,
@@ -1653,15 +1651,14 @@ def __init__(
16531651
key=key_proj,
16541652
)
16551653

1656-
self.attn_drop = eqx.nn.Dropout(attn_drop)
16571654
self.proj_drop = eqx.nn.Dropout(proj_drop)
16581655

16591656
def __call__(
16601657
self,
1661-
x: Float[Array, "seqlen height width"],
1658+
x: Float[Array, "dim height width"],
16621659
key: PRNGKeyArray,
16631660
inference: Optional[bool] = None,
1664-
) -> Float[Array, "seqlen height width"]:
1661+
) -> Float[Array, "dim height width"]:
16651662
qkv_base = self.qkv(x)
16661663

16671664
aggregated_qkvs = [op(qkv_base) for op in self.aggreg]
@@ -1684,6 +1681,7 @@ def __call__(
16841681

16851682
out = (q * sum_kv) / (sum_q * sum_k + self.eps)
16861683
out = self.proj(out)
1684+
out = self.proj_drop(out, inference=inference, key=key)
16871685

16881686
return out
16891687

0 commit comments

Comments
 (0)