forked from utahnlp/consistency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinear_classifier.py
52 lines (31 loc) · 854 Bytes
/
linear_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
import torch
from torch import nn
from torch.autograd import Variable
from holder import *
from util import *
from locked_dropout import *
# linear classifier
class LinearClassifier(torch.nn.Module):
	"""Linear classification head.

	Selects the first token's encoding from the concatenated encoder output
	(CLS-style pooling) and maps it to per-label log-probabilities with a
	single linear layer.
	"""
	def __init__(self, opt, shared):
		"""
		Args:
			opt: option holder; reads opt.hidden_size, opt.num_label and opt.fp16.
			shared: cross-module scratch holder; forward() writes shared.y_scores.
		"""
		super(LinearClassifier, self).__init__()
		self.opt = opt
		self.shared = shared
		# weights will be initialized later by the caller
		# (a Dropout layer used to precede the Linear but was disabled)
		self.linear = nn.Sequential(
			nn.Linear(opt.hidden_size, opt.num_label))
		# build once here instead of re-instantiating an nn.LogSoftmax
		# module on every forward call; it has no parameters, so the
		# state_dict is unchanged
		self.log_softmax = nn.LogSoftmax(1)
		self.fp16 = opt.fp16 == 1

	def forward(self, concated):
		"""Compute label log-probabilities.

		Args:
			concated: tensor of shape (batch_l, concated_l, enc_size).
		Returns:
			Tensor of shape (batch_l, num_label) holding log-probabilities.
		Side effects:
			Stores the pre-softmax scores in self.shared.y_scores.
		"""
		# pool by taking only the first token's encoding
		head = concated[:, 0, :]
		scores = self.linear(head)	# (batch_l, num_label)
		log_p = self.log_softmax(scores)
		self.shared.y_scores = scores
		return log_p

	def begin_pass(self):
		# hook called at the start of a train/eval pass; nothing to do here
		pass

	def end_pass(self):
		# hook called at the end of a pass; nothing to do here
		pass
# No standalone behavior: this module is meant to be imported, not run directly.
if __name__ == '__main__':
	pass