diff --git a/onmt/modules/Transformer.py b/onmt/modules/Transformer.py index 6e3b3dbb..77b07cab 100644 --- a/onmt/modules/Transformer.py +++ b/onmt/modules/Transformer.py @@ -29,8 +29,8 @@ def __init__(self, size, hidden_size, dropout=0.1): self.w_1 = nn.Linear(size, hidden_size) self.w_2 = nn.Linear(hidden_size, size) self.layer_norm = onmt.modules.LayerNorm(size) + self.dropout_1 = nn.Dropout(dropout) # Save a little memory, by doing inplace. - self.dropout_1 = nn.Dropout(dropout, inplace=True) self.relu = nn.ReLU(inplace=True) self.dropout_2 = nn.Dropout(dropout)