@@ -102,10 +102,11 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
    Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
    Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
    bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
-   Wy, by, ln_in_mu, ln_in_scale = params
+   Wy, by, ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2 = params
    cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
    if use_LN_input:
        learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
+       encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
    features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
    # Perform a single self-attention block here
    # Self-Attention
@@ -261,7 +262,9 @@ def __init__(
        # Finally, define ln for the input to the attention
        ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
-       ln_in_params = (ln_in_mu, ln_in_scale)
+       ln_in_mu2 = jnp.zeros((1, input_dim)) ## LN parameter
+       ln_in_scale2 = jnp.ones((1, input_dim)) ## LN parameter
+       ln_in_params = (ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2)
        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)

        ## set up gradient calculator
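For reference, a minimal sketch of what the `layer_normalize` helper used above is assumed to do: standard layer normalization over the feature (last) axis, followed by the learnable shift (`mu`, initialized to zeros) and scale (initialized to ones) created in `__init__`. The epsilon value and the exact shift/scale semantics are assumptions here, not taken from the repository's actual implementation.

import jax.numpy as jnp

def layer_normalize(x, mu, scale, eps=1e-6):
    # Normalize each feature vector to zero mean / unit variance over the
    # last axis, then apply the learnable shift (mu) and scale parameters.
    # (Sketch only; eps and shift/scale semantics are assumed.)
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return scale * (x - mean) / jnp.sqrt(var + eps) + mu

Because the new `ln_in_mu2` / `ln_in_scale2` parameters have shape `(1, input_dim)`, such a helper would broadcast them across the token dimension of `encodings`, mirroring how the `(1, learnable_query_dim)` parameters are applied to `learnable_query`.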