Commit
Merge pull request #7 from psoulos/patch-1
Update the initialisation of init_state and learned_ema_beta to match…
lucidrains authored Aug 20, 2024
2 parents b22f552 + e19918f commit 64838b2
Showing 1 changed file with 4 additions and 0 deletions.
@@ -330,7 +330,10 @@ def __init__(
         self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)
 
         self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
+        torch.nn.init.normal_(self.init_state, 0, .1)
         self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))
+        # NOTE: the state position id embeddings are drawn from N(0,1) since they are added after a layer norm
+        torch.nn.init.normal_(self.state_pos_ids, 0, 1)
 
         self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)
 
@@ -343,6 +346,7 @@ def __init__(
 
         self.state_out_to_gate = nn.Linear(dim, dim)
         self.learned_ema_beta = nn.Parameter(torch.randn(dim))
+        torch.nn.init.normal_(self.learned_ema_beta, 0, .1)
 
         # since each read should be followed by a write, just store cache in the container
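For context, here is a minimal runnable sketch of what the four added lines do. The StateContainer class name and its default sizes are hypothetical stand-ins, not the repository's actual module; only the three parameter initializations mirror the commit. Each parameter is first created with torch.randn, i.e. drawn from N(0, 1), and the new torch.nn.init.normal_ calls then re-draw init_state and learned_ema_beta in place from the tighter N(0, 0.1^2), while state_pos_ids is deliberately kept at N(0, 1) because it is added after a layer norm.

import torch
from torch import nn

class StateContainer(nn.Module):
    # hypothetical minimal module isolating the three parameters touched by this commit
    def __init__(self, num_state_vectors = 32, dim = 64):
        super().__init__()

        # created from N(0, 1) by torch.randn, then re-drawn in place from N(0, 0.1^2)
        self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
        torch.nn.init.normal_(self.init_state, 0, .1)

        # kept at N(0, 1): the position id embeddings are added after a layer norm
        self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))
        torch.nn.init.normal_(self.state_pos_ids, 0, 1)

        # likewise re-drawn from N(0, 0.1^2)
        self.learned_ema_beta = nn.Parameter(torch.randn(dim))
        torch.nn.init.normal_(self.learned_ema_beta, 0, .1)

container = StateContainer()
print(container.init_state.std().item())     # roughly 0.1 after re-initialization
print(container.state_pos_ids.std().item())  # roughly 1.0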
