
Commit 7e9b93e

add hidden size

1 parent c9645c8

File tree: 1 file changed (+3, −7 lines)


models/LSTM.py

Lines changed: 3 additions & 7 deletions
@@ -12,22 +12,18 @@ def __init__(self, vocab_len, config):
         super().__init__()
         self.num_labels = config.num_labels
         self.embed = nn.Embedding(num_embeddings=vocab_len, embedding_dim=config.embed_dim)
-        self.dropout = nn.Dropout(config.dropout_rate)
-        self.lstm = nn.LSTM(input_size=config.embed_dim, hidden_size=config.embed_dim, batch_first=True, bidirectional=True)
-        self.layer_norm = nn.LayerNorm(config.embed_dim * 2)
-        self.classifier = nn.Linear(config.embed_dim * 2, config.num_labels)
+        self.lstm = nn.LSTM(input_size=config.embed_dim, hidden_size=config.hidden_size, batch_first=True, bidirectional=True)
+        self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)
         self.crf = CRF(num_tags=config.num_labels, batch_first=True)
         self.loss_fct = nn.CrossEntropyLoss()

     def forward(self, word_ids, label_ids=None, label_mask=None, use_crf=True):
         word_embed = self.embed(word_ids)
-        word_embed = self.dropout(word_embed)
         sequence_output, _ = self.lstm(word_embed)
-        sequence_output = self.layer_norm(sequence_output)
         logits = self.classifier(sequence_output)
         if label_ids != None:
             if use_crf:
-                loss = self.crf(logits, label_ids)
+                loss = self.crf(logits, label_ids, label_mask)
                 loss = -1 * loss
             else:
                 active_logits = logits.view(-1, self.num_labels)
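
The commit drops the dropout and layer-norm layers, decouples the LSTM's hidden size from the embedding dimension, and passes label_mask through to the CRF. Below is a minimal runnable sketch of the module as it stands after this commit. The class name LSTMTagger, the toy config values, and the assumption that CRF comes from the pytorch-crf package (torchcrf.CRF) are illustrative, not taken from the repository; the non-CRF branch, which the diff truncates, is omitted.

    # Sketch of models/LSTM.py after this commit.
    # Assumptions: CRF is torchcrf.CRF (pip install pytorch-crf);
    # config exposes embed_dim, hidden_size, and num_labels as in the diff;
    # the class name LSTMTagger is hypothetical.
    from types import SimpleNamespace

    import torch
    import torch.nn as nn
    from torchcrf import CRF


    class LSTMTagger(nn.Module):
        def __init__(self, vocab_len, config):
            super().__init__()
            self.num_labels = config.num_labels
            self.embed = nn.Embedding(num_embeddings=vocab_len,
                                      embedding_dim=config.embed_dim)
            # hidden_size is now decoupled from embed_dim (this commit).
            self.lstm = nn.LSTM(input_size=config.embed_dim,
                                hidden_size=config.hidden_size,
                                batch_first=True, bidirectional=True)
            # A bidirectional LSTM doubles the feature dimension.
            self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)
            self.crf = CRF(num_tags=config.num_labels, batch_first=True)
            # Used by the non-CRF branch, which the diff truncates.
            self.loss_fct = nn.CrossEntropyLoss()

        def forward(self, word_ids, label_ids=None, label_mask=None, use_crf=True):
            word_embed = self.embed(word_ids)                # (B, T, embed_dim)
            sequence_output, _ = self.lstm(word_embed)       # (B, T, 2*hidden_size)
            logits = self.classifier(sequence_output)        # (B, T, num_labels)
            if label_ids is not None and use_crf:
                # torchcrf returns a log-likelihood; negate it for a loss.
                # The mask is now forwarded to the CRF (this commit).
                return -self.crf(logits, label_ids, label_mask)
            return logits


    # Hypothetical usage with toy shapes.
    config = SimpleNamespace(embed_dim=64, hidden_size=128, num_labels=5)
    model = LSTMTagger(vocab_len=1000, config=config)
    word_ids = torch.randint(0, 1000, (2, 10))
    label_ids = torch.randint(0, 5, (2, 10))
    label_mask = torch.ones(2, 10, dtype=torch.bool)
    loss = model(word_ids, label_ids, label_mask)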
