@@ -97,7 +97,7 @@ def __call__(
        attn = jnp.where(mask == 0, jnp.finfo(attn.dtype).min, attn)

        attn = jax.nn.softmax(attn, axis=-1)
-        attn = self.attn_drop(attn, inference=inference, key=key1)
+        attn = self.attn_drop(attn, inference=inference, key=key1)

        x = jnp.einsum("hqk,hkd->hqd", attn, v)
        x = rearrange(x, "h s d -> s (h d)")
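
For context, the hunk above only shows the tail of the softmax-attention path. Below is a self-contained sketch of the same sequence, assuming a (heads, seqlen, head_dim) layout and a standard 1/sqrt(d) scale that are not shown in the diff; only the where / softmax / Dropout / einsum / rearrange steps come from the code above.

import jax
import jax.numpy as jnp
import equinox as eqx
from einops import rearrange

def masked_softmax_attention(q, k, v, mask, attn_drop: eqx.nn.Dropout, *, inference, key):
    # q, k, v: (heads, seqlen, head_dim); mask broadcastable to (heads, q_len, k_len).
    attn = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])  # scale is an assumption
    attn = jnp.where(mask == 0, jnp.finfo(attn.dtype).min, attn)     # mask out disallowed keys
    attn = jax.nn.softmax(attn, axis=-1)
    attn = attn_drop(attn, inference=inference, key=key)             # dropout on the attention weights
    x = jnp.einsum("hqk,hkd->hqd", attn, v)
    return rearrange(x, "h s d -> s (h d)")                          # merge heads back into the channel dim
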
@@ -1591,7 +1591,6 @@ class RFAttention(eqx.Module):
    qkv: eqx.nn.Conv2d
    aggreg: list[eqx.nn.Conv2d]
    proj: SingleConvBlock
-    attn_drop: eqx.nn.Dropout
    proj_drop: eqx.nn.Dropout

    def __init__(
@@ -1606,7 +1605,6 @@ def __init__(
        scales: Sequence[int] = (5,),
        use_bias: bool = False,
        kernel_func: Callable = jax.nn.relu,
-        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        # TODO: Benchmark against LN, RMSN, NsLN
        norm_layer: eqx.Module = eqx.nn.GroupNorm,
@@ -1653,15 +1651,14 @@ def __init__(
            key=key_proj,
        )

-        self.attn_drop = eqx.nn.Dropout(attn_drop)
        self.proj_drop = eqx.nn.Dropout(proj_drop)

    def __call__(
        self,
-        x: Float[Array, "seqlen height width"],
+        x: Float[Array, "dim height width"],
        key: PRNGKeyArray,
        inference: Optional[bool] = None,
-    ) -> Float[Array, "seqlen height width"]:
+    ) -> Float[Array, "dim height width"]:
        qkv_base = self.qkv(x)

        aggregated_qkvs = [op(qkv_base) for op in self.aggreg]
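
The new "dim height width" annotation makes the channel-first, unbatched layout explicit. The following is a rough, hypothetical sketch of the projection/aggregation pattern in this hunk; the 1x1 projection, the 3 * dim channel width, and the depthwise grouping/padding are assumptions, only the eqx.nn.Conv2d qkv and the per-scale list comprehension over self.aggreg come from the diff.

import jax
import equinox as eqx

dim, scales = 64, (5,)                      # hypothetical sizes
k_qkv, *k_aggreg = jax.random.split(jax.random.PRNGKey(0), 1 + len(scales))

qkv = eqx.nn.Conv2d(dim, 3 * dim, kernel_size=1, key=k_qkv)
aggreg = [
    # one depthwise conv per scale (grouping and padding are assumptions)
    eqx.nn.Conv2d(3 * dim, 3 * dim, kernel_size=s, padding=s // 2, groups=3 * dim, key=k)
    for s, k in zip(scales, k_aggreg)
]

x = jax.random.normal(jax.random.PRNGKey(1), (dim, 32, 32))  # (dim, height, width), unbatched as in Equinox
qkv_base = qkv(x)                                            # (3 * dim, 32, 32)
aggregated_qkvs = [op(qkv_base) for op in aggreg]            # one multi-scale qkv per entry in scales
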
@@ -1684,6 +1681,7 @@ def __call__(

        out = (q * sum_kv) / (sum_q * sum_k + self.eps)
        out = self.proj(out)
+        out = self.proj_drop(out, inference=inference, key=key)

        return out

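With attn_drop removed, the only stochastic piece left in RFAttention is the proj_drop applied to the projected output. As a reminder of the eqx.nn.Dropout calling convention the new line relies on (the array shape below is just an example):

import jax
import jax.numpy as jnp
import equinox as eqx

drop = eqx.nn.Dropout(p=0.1)
x = jnp.ones((64, 32, 32))                                     # e.g. a (dim, height, width) feature map

y_train = drop(x, inference=False, key=jax.random.PRNGKey(0))  # stochastic; kept values rescaled by 1/(1 - p)
y_eval = drop(x, inference=True)                               # identity, no key needed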