|
4 | 4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack |
5 | 5 | import modules.data_loader as data_loader |
6 | 6 |
|
| 7 | +class Attention(nn.Module): |
| 8 | + |
| 9 | + def __init__(self, hidden_size): |
| 10 | + super(Attention, self).__init__() |
| 11 | + |
| 12 | + self.linear = nn.Linear(hidden_size, hidden_size, bias=False) |
| 13 | + self.softmax = nn.Softmax(dim=-1) |
| 14 | + |
| 15 | + def forward(self, h_src, h_t_tgt, mask=None): |
| 16 | + ''' |
| 17 | + Encoder 단의 원문 벡터값과 현재 값의 유사도 산출 |
| 18 | + h_src: Encoder 전체 time-step의 히든 스테이트 출력값 |
| 19 | + h_t_tgt: Decoder의 현재 time-step 히든 스테이트 출력값 |
| 20 | + mask: 각 문장 각 토큰에 대한 PAD 여부 (True or False) |
| 21 | + ''' |
| 22 | + # h_src = (batch_size, length, hidden_size) |
| 23 | + # h_t_tgt = (batch_size, 1, hidden_size) |
| 24 | + # mask = (batch_size, length) |
| 25 | + |
| 26 | + # query = (batch_size, 1, hidden_size) |
| 27 | + query = self.linear(h_t_tgt) |
| 28 | + # weight = (batch_size, 1, length) |
| 29 | + # length -> encoder 단 모든 타임스텝 결과에 대한 가중치를 뜻함 |
| 30 | + weight = torch.bmm(query, h_src.transpose(1, 2)) |
| 31 | + |
| 32 | + if mask is not None: |
| 33 | + # PAD token 자리의 가중치를 모두 -inf로 치환(학습 미반영) |
| 34 | + # 마스크를 씌우기 위해 mask가 해당 weight의 shape과 같아야함 |
| 35 | + # mask.unsqueeze(1) = (batch_size, 1, length) |
| 36 | + weight.masked_fill_(mask.unsqueeze(1), -float('inf')) |
| 37 | + |
| 38 | + # weight = (batch_size, 1, length) |
| 39 | + weight = self.softmax(weight) |
| 40 | + # h_src = (batch_size, length, hidden_size) |
| 41 | + # context_vector = (batch_size, 1, hidden_size) |
| 42 | + context_vector = torch.bmm(weight, h_src) |
| 43 | + return context_vector |
| 44 | + |
7 | 45 |
|
8 | 46 | class Encoder(nn.Module): |
9 | 47 |
|
@@ -91,45 +129,6 @@ def forward(self, emb_t, h_t_1_tilde, h_t_1): |
91 | 129 | return y, h |
92 | 130 |
|
93 | 131 |
|
94 | | -class Attention(nn.Module): |
95 | | - |
96 | | - def __init__(self, hidden_size): |
97 | | - super(Attention, self).__init__() |
98 | | - |
99 | | - self.linear = nn.Linear(hidden_size, hidden_size, bias=False) |
100 | | - self.softmax = nn.Softmax(dim=-1) |
101 | | - |
102 | | - def forward(self, h_src, h_t_tgt, mask=None): |
103 | | - ''' |
104 | | - Encoder 단의 원문 벡터값과 현재 값의 유사도 산출 |
105 | | - h_src: Encoder 전체 time-step의 히든 스테이트 출력값 |
106 | | - h_t_tgt: Decoder의 현재 time-step 히든 스테이트 출력값 |
107 | | - mask: 각 문장 각 토큰에 대한 PAD 여부 (True or False) |
108 | | - ''' |
109 | | - # h_src = (batch_size, length, hidden_size) |
110 | | - # h_t_tgt = (batch_size, 1, hidden_size) |
111 | | - # mask = (batch_size, length) |
112 | | - |
113 | | - # query = (batch_size, 1, hidden_size) |
114 | | - query = self.linear(h_t_tgt) |
115 | | - # weight = (batch_size, 1, length) |
116 | | - # length -> encoder 단 모든 타임스텝 결과에 대한 가중치를 뜻함 |
117 | | - weight = torch.bmm(query, h_src.transpose(1, 2)) |
118 | | - |
119 | | - if mask is not None: |
120 | | - # PAD token 자리의 가중치를 모두 -inf로 치환(학습 미반영) |
121 | | - # 마스크를 씌우기 위해 mask가 해당 weight의 shape과 같아야함 |
122 | | - # mask.unsqueeze(1) = (batch_size, 1, length) |
123 | | - weight.masked_fill_(mask.unsqueeze(1), -float('inf')) |
124 | | - |
125 | | - # weight = (batch_size, 1, length) |
126 | | - weight = self.softmax(weight) |
127 | | - # h_src = (batch_size, length, hidden_size) |
128 | | - # context_vector = (batch_size, 1, hidden_size) |
129 | | - context_vector = torch.bmm(weight, h_src) |
130 | | - return context_vector |
131 | | - |
132 | | - |
133 | 132 | class Generator(nn.Module): |
134 | 133 |
|
135 | 134 | def __init__(self, hidden_size, output_size): |
|
0 commit comments