Skip to content

Commit 0207c03

Browse files
committed
Update seq2seq.py
1 parent f7a1ba1 commit 0207c03

File tree

1 file changed

+38
-39
lines changed

1 file changed

+38
-39
lines changed

src/11_seq2seq/modules/seq2seq.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,44 @@
44
from torch.nn.utils.rnn import pad_packed_sequence as unpack
55
import modules.data_loader as data_loader
66

7+
class Attention(nn.Module):
8+
9+
def __init__(self, hidden_size):
10+
super(Attention, self).__init__()
11+
12+
self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
13+
self.softmax = nn.Softmax(dim=-1)
14+
15+
def forward(self, h_src, h_t_tgt, mask=None):
16+
'''
17+
Encoder 단의 원문 벡터값과 현재 값의 유사도 산출
18+
h_src: Encoder 전체 time-step의 히든 스테이트 출력값
19+
h_t_tgt: Decoder의 현재 time-step 히든 스테이트 출력값
20+
mask: 각 문장 각 토큰에 대한 PAD 여부 (True or False)
21+
'''
22+
# h_src = (batch_size, length, hidden_size)
23+
# h_t_tgt = (batch_size, 1, hidden_size)
24+
# mask = (batch_size, length)
25+
26+
# query = (batch_size, 1, hidden_size)
27+
query = self.linear(h_t_tgt)
28+
# weight = (batch_size, 1, length)
29+
# length -> encoder 단 모든 타임스텝 결과에 대한 가중치를 뜻함
30+
weight = torch.bmm(query, h_src.transpose(1, 2))
31+
32+
if mask is not None:
33+
# PAD token 자리의 가중치를 모두 -inf로 치환(학습 미반영)
34+
# 마스크를 씌우기 위해 mask가 해당 weight의 shape과 같아야함
35+
# mask.unsqueeze(1) = (batch_size, 1, length)
36+
weight.masked_fill_(mask.unsqueeze(1), -float('inf'))
37+
38+
# weight = (batch_size, 1, length)
39+
weight = self.softmax(weight)
40+
# h_src = (batch_size, length, hidden_size)
41+
# context_vector = (batch_size, 1, hidden_size)
42+
context_vector = torch.bmm(weight, h_src)
43+
return context_vector
44+
745

846
class Encoder(nn.Module):
947

@@ -91,45 +129,6 @@ def forward(self, emb_t, h_t_1_tilde, h_t_1):
91129
return y, h
92130

93131

94-
class Attention(nn.Module):
95-
96-
def __init__(self, hidden_size):
97-
super(Attention, self).__init__()
98-
99-
self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
100-
self.softmax = nn.Softmax(dim=-1)
101-
102-
def forward(self, h_src, h_t_tgt, mask=None):
103-
'''
104-
Encoder 단의 원문 벡터값과 현재 값의 유사도 산출
105-
h_src: Encoder 전체 time-step의 히든 스테이트 출력값
106-
h_t_tgt: Decoder의 현재 time-step 히든 스테이트 출력값
107-
mask: 각 문장 각 토큰에 대한 PAD 여부 (True or False)
108-
'''
109-
# h_src = (batch_size, length, hidden_size)
110-
# h_t_tgt = (batch_size, 1, hidden_size)
111-
# mask = (batch_size, length)
112-
113-
# query = (batch_size, 1, hidden_size)
114-
query = self.linear(h_t_tgt)
115-
# weight = (batch_size, 1, length)
116-
# length -> encoder 단 모든 타임스텝 결과에 대한 가중치를 뜻함
117-
weight = torch.bmm(query, h_src.transpose(1, 2))
118-
119-
if mask is not None:
120-
# PAD token 자리의 가중치를 모두 -inf로 치환(학습 미반영)
121-
# 마스크를 씌우기 위해 mask가 해당 weight의 shape과 같아야함
122-
# mask.unsqueeze(1) = (batch_size, 1, length)
123-
weight.masked_fill_(mask.unsqueeze(1), -float('inf'))
124-
125-
# weight = (batch_size, 1, length)
126-
weight = self.softmax(weight)
127-
# h_src = (batch_size, length, hidden_size)
128-
# context_vector = (batch_size, 1, hidden_size)
129-
context_vector = torch.bmm(weight, h_src)
130-
return context_vector
131-
132-
133132
class Generator(nn.Module):
134133

135134
def __init__(self, hidden_size, output_size):

0 commit comments

Comments
 (0)