@@ -12,22 +12,18 @@ def __init__(self, vocab_len, config):
         super().__init__()
         self.num_labels = config.num_labels
         self.embed = nn.Embedding(num_embeddings=vocab_len, embedding_dim=config.embed_dim)
-        self.dropout = nn.Dropout(config.dropout_rate)
-        self.lstm = nn.LSTM(input_size=config.embed_dim, hidden_size=config.embed_dim, batch_first=True, bidirectional=True)
-        self.layer_norm = nn.LayerNorm(config.embed_dim * 2)
-        self.classifier = nn.Linear(config.embed_dim * 2, config.num_labels)
+        self.lstm = nn.LSTM(input_size=config.embed_dim, hidden_size=config.hidden_size, batch_first=True, bidirectional=True)
+        self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)
         self.crf = CRF(num_tags=config.num_labels, batch_first=True)
         self.loss_fct = nn.CrossEntropyLoss()

     def forward(self, word_ids, label_ids=None, label_mask=None, use_crf=True):
         word_embed = self.embed(word_ids)
-        word_embed = self.dropout(word_embed)
         sequence_output, _ = self.lstm(word_embed)
-        sequence_output = self.layer_norm(sequence_output)
         logits = self.classifier(sequence_output)
         if label_ids != None:
             if use_crf:
-                loss = self.crf(logits, label_ids)
+                loss = self.crf(logits, label_ids, label_mask)
                 loss = -1 * loss
             else:
                 active_logits = logits.view(-1, self.num_labels)
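The updated forward() now hands label_mask to the CRF so that padded positions are excluded from the log-likelihood. Below is a minimal sketch of that masked loss computation; it assumes the CRF layer is torchcrf.CRF from the pytorch-crf package (its constructor matches the one used above), and the tensor sizes and the example mask are made up purely for illustration.

# Sketch only (not part of the commit): masked CRF negative log-likelihood,
# assuming CRF == torchcrf.CRF from pytorch-crf.
import torch
from torchcrf import CRF

batch_size, seq_len, num_labels = 2, 5, 4

crf = CRF(num_tags=num_labels, batch_first=True)

# "logits" stands in for the classifier output, shape (batch, seq_len, num_labels).
logits = torch.randn(batch_size, seq_len, num_labels)
label_ids = torch.randint(0, num_labels, (batch_size, seq_len))

# pytorch-crf expects the mask as a bool (or uint8) tensor with 0 at padded positions;
# the first timestep of every sequence must stay unmasked.
label_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
label_mask[1, 3:] = False  # pretend the second sequence is padded after step 3

log_likelihood = crf(logits, label_ids, label_mask)  # summed log-likelihood over the batch
loss = -1 * log_likelihood                           # negate to get a loss to minimize

Without the mask argument, the CRF would score the padded timesteps as if they were real labels, which is what the changed line in the commit avoids.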