Skip to content

Commit

Permalink
Merge pull request #14 from GJ98/master
Browse files Browse the repository at this point in the history
fix pad bug
  • Loading branch information
hyunwoongko authored Jan 14, 2023
2 parents 3821253 + 67ccbb0 commit 424c6da
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions models/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,27 @@ def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_
device=device)

def forward(self, src, trg):
src_mask = self.make_pad_mask(src, src)
src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)

src_trg_mask = self.make_pad_mask(trg, src)
src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

trg_mask = self.make_pad_mask(trg, trg) * \
trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) * \
self.make_no_peak_mask(trg, trg)

enc_src = self.encoder(src, src_mask)
output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
return output

def make_pad_mask(self, q, k):
def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
len_q, len_k = q.size(1), k.size(1)

# batch_size x 1 x 1 x len_k
k = k.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(2)
k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
# batch_size x 1 x len_q x len_k
k = k.repeat(1, 1, len_q, 1)

# batch_size x 1 x len_q x 1
q = q.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(3)
q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
# batch_size x 1 x len_q x len_k
q = q.repeat(1, 1, 1, len_k)

Expand Down

0 comments on commit 424c6da

Please sign in to comment.