diff --git a/block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py b/block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py index daa5a78..c4e9109 100644 --- a/block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py +++ b/block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py @@ -330,7 +330,10 @@ def __init__( self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False) self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim)) + torch.nn.init.normal_(self.init_state, 0, .1) self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim)) + # NOTE: the state position id embeddings are drawn from N(0,1) since they are added after a layer norm + torch.nn.init.normal_(self.state_pos_ids, 0, 1) self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False) @@ -343,6 +346,7 @@ def __init__( self.state_out_to_gate = nn.Linear(dim, dim) self.learned_ema_beta = nn.Parameter(torch.randn(dim)) + torch.nn.init.normal_(self.learned_ema_beta, 0, .1) # since each read should be followed by a write, just store cache in the container