From ea79faef45a9e91b4c4910e01dfd1a6889af5b2c Mon Sep 17 00:00:00 2001 From: Hehe Fan Date: Wed, 30 Mar 2022 23:10:12 +0800 Subject: [PATCH] Update transformer.py --- modules-pytorch-1.8.1/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules-pytorch-1.8.1/transformer.py b/modules-pytorch-1.8.1/transformer.py index 621fc7a..0636a4e 100644 --- a/modules-pytorch-1.8.1/transformer.py +++ b/modules-pytorch-1.8.1/transformer.py @@ -70,7 +70,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ - Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), + Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = 0.))), Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) ])) def forward(self, x):