Skip to content

Commit

Permalink
Update models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dapowan authored Oct 16, 2021
1 parent 3e248e1 commit 883a65f
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class LIMUBertModel4Pretrain(nn.Module):

def __init__(self, cfg, output_embed=False):
super().__init__()
self.transformer = Transformer(cfg)
self.transformer = Transformer(cfg) # encoder
self.fc = nn.Linear(cfg.hidden, cfg.hidden)
self.linear = nn.Linear(cfg.hidden, cfg.hidden)
self.activ = gelu
Expand Down Expand Up @@ -209,9 +209,8 @@ def forward(self, input_seqs, training=False):
lstm = self.__getattr__('lstm' + str(i))
bn = self.__getattr__('bn' + str(i))
h, _ = lstm(h)
# if self.activ:
# h = F.relu(h)
# h = bn(h)
if self.activ:
h = F.relu(h)
h = h[:, -1, :]
if self.dropout:
h = F.dropout(h, training=training)
Expand Down Expand Up @@ -593,4 +592,4 @@ def fetch_classifier(method, model_cfg, input=None, output=None, feats=False):
model = ClassifierAttn(model_cfg, input=input, output=output)
else:
model = None
return model
return model

0 comments on commit 883a65f

Please sign in to comment.