Skip to content

Commit 8682954

update input layer normalization
1 parent aeabf61 commit 8682954

File tree

1 file changed (+5, -2 lines changed)


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 5 additions & 2 deletions
@@ -102,10 +102,11 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
     Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
     bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
-    Wy, by, ln_in_mu, ln_in_scale = params
+    Wy, by, ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2 = params
     cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
+        encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
     features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
@@ -261,7 +262,9 @@ def __init__(
         # Finally, define ln for the input to the attention
         ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
         ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
-        ln_in_params = (ln_in_mu, ln_in_scale)
+        ln_in_mu2 = jnp.zeros((1, input_dim)) ## LN parameter
+        ln_in_scale2 = jnp.ones((1, input_dim)) ## LN parameter
+        ln_in_params = (ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2)
         self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)

         ## set up gradient calculator
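
Note: the updated `if use_LN_input:` branch now layer-normalizes both the learnable query and the raw `encodings`, each with its own shift/scale pair, because the query lives in `learnable_query_dim` while the encodings live in `input_dim`. The `layer_normalize` helper itself is not part of this diff; the sketch below (hypothetical name `layer_normalize_sketch`) only illustrates what a call such as `layer_normalize(encodings, ln_in_mu2, ln_in_scale2)` is assumed to compute, i.e. standard layer normalization with a learned shift (`mu`, initialized to zeros) and gain (`scale`, initialized to ones). The epsilon value and other details are assumptions, not ngclearn's actual implementation.

import jax.numpy as jnp

def layer_normalize_sketch(x, mu, scale, eps=1e-6):
    # Normalize each row of x over its last (feature) axis.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    # Apply the learned per-feature gain and shift, each of shape (1, feature_dim),
    # broadcast over the batch dimension.
    return x_hat * scale + mu

Under this reading, normalizing encodings of shape (batch, input_dim) with parameters of shape (1, input_dim) leaves the shape unchanged, which is why the new `ln_in_mu2` / `ln_in_scale2` parameters are created with `input_dim` rather than `learnable_query_dim`.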
