pad mask issue and embedding issue
GJ98 committed Apr 9, 2021
1 parent db8c1e3 commit fa78973
Showing 32 changed files with 33 additions and 26 deletions.
Binary file added __pycache__/conf.cpython-37.pyc
Binary file added __pycache__/data.cpython-37.pyc
Binary file added models/__pycache__/__init__.cpython-37.pyc
Binary file added models/blocks/__pycache__/__init__.cpython-37.pyc
4 changes: 2 additions & 2 deletions models/blocks/decoder_layer.py
@@ -29,7 +29,7 @@ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
     def forward(self, dec, enc, t_mask, s_mask):
         # 1. compute self attention
         _x = dec
-        x = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)
+        x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)
 
         # 2. add and norm
         x = self.norm1(x + _x)
@@ -38,7 +38,7 @@ def forward(self, dec, enc, t_mask, s_mask):
         if enc is not None:
             # 3. compute encoder - decoder attention
             _x = x
-            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
+            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)
 
             # 4. add and norm
             x = self.norm2(x + _x)
2 changes: 1 addition & 1 deletion models/blocks/encoder_layer.py
@@ -25,7 +25,7 @@ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
     def forward(self, x, s_mask):
         # 1. compute self attention
         _x = x
-        x = self.attention(q=x, k=x, v=x, mask=src_mask)
+        x = self.attention(q=x, k=x, v=x, mask=s_mask)
 
         # 2. add and norm
         x = self.norm1(x + _x)
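Both layer diffs fix the same class of bug: the forward bodies referenced trg_mask and src_mask, names that do not exist in the method scope, instead of the t_mask and s_mask parameters actually passed in, so any call raised a NameError. A minimal, self-contained sketch of the corrected pattern (SelfAttentionBlock and its internals are assumptions for illustration, not the repository's classes):

import math
import torch
from torch import nn

class SelfAttentionBlock(nn.Module):
    """Toy stand-in for the encoder layer's attention + add-and-norm."""

    def __init__(self, d_model):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.norm1 = nn.LayerNorm(d_model)

    def forward(self, x, s_mask):
        _x = x
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        score = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        # the fix: read the mask from the parameter, not an undefined global
        score = score.masked_fill(s_mask == 0, float('-inf'))
        x = score.softmax(dim=-1) @ v
        return self.norm1(x + _x)

layer = SelfAttentionBlock(d_model=8)
x = torch.randn(2, 5, 8)                        # batch x length x d_model
s_mask = torch.ones(2, 5, 5, dtype=torch.bool)  # batch x len_q x len_k
out = layer(x, s_mask)                          # (2, 5, 8), no NameError

With the rename, the masks built in transformer.py flow through the layers to the attention scores unchanged.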
Binary file added models/embedding/__pycache__/__init__.cpython-37.pyc
2 changes: 1 addition & 1 deletion models/embedding/token_embeddings.py
@@ -19,4 +19,4 @@ class for token embedding that included positional information
         :param vocab_size: size of vocabulary
         :param d_model: dimensions of model
         """
-        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=0)
+        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)
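For context on the padding_idx change: nn.Embedding pins the row at padding_idx to zeros and excludes it from gradient updates, so the value must match the index the vocabulary actually assigns to the pad token. With the torchtext vocab used here, <pad> presumably sits at index 1 (after <unk> at 0), hence the move from 0 to 1. A small demonstration (the vocab layout in the comment is an assumption):

import torch
from torch import nn

# assumed vocab layout: {'<unk>': 0, '<pad>': 1, ...}
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=1)

tokens = torch.tensor([[2, 5, 1, 1]])  # trailing 1s are <pad>
out = emb(tokens)
print(out[0, 2])                       # all zeros: the padding row is pinned to 0

out.sum().backward()
print(emb.weight.grad[1])              # all zeros: the padding row gets no gradient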
Binary file added models/layers/__pycache__/__init__.cpython-37.pyc
2 changes: 2 additions & 0 deletions models/layers/scale_dot_product_attention.py
@@ -30,6 +30,8 @@ def forward(self, q, k, v, mask=None, e=1e-12):
         k_t = k.view(batch_size, head, d_tensor, length)  # transpose
         score = (q @ k_t) / math.sqrt(d_tensor)  # scaled dot product
 
+        print("score : {}" .format(score.size()))
+        print("mask : {}" .format(mask.size()))
         # 2. apply masking (opt)
         if mask is not None:
             score = score.masked_fill(mask == 0, -e)
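The added prints verify that the score tensor (batch x head x len_q x len_k) and the pad mask built in transformer.py (batch x 1 x len_q x len_k) broadcast against each other, the unit head dimension expanding across all heads. One hedged observation: with the default e=1e-12, the fill value -e is a tiny negative number, so softmax barely suppresses masked positions; the conventional choice is a large negative constant such as -1e9 or float('-inf'). A standalone sketch of that conventional form (the function name and shapes are assumptions; it also uses transpose rather than view to swap the last two axes):

import math
import torch

def scale_dot_product_attention(q, k, v, mask=None):
    # q, k, v: batch x head x length x d_tensor
    score = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        # large negative fill, so softmax sends masked positions to ~0
        score = score.masked_fill(mask == 0, -1e9)
    attn = score.softmax(dim=-1)
    return attn @ v, attn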
Binary file added models/model/__pycache__/__init__.cpython-37.pyc
Binary file added models/model/__pycache__/decoder.cpython-37.pyc
Binary file added models/model/__pycache__/encoder.cpython-37.pyc
40 changes: 23 additions & 17 deletions models/model/transformer.py
@@ -38,31 +38,37 @@ def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_
                            device=device)
 
     def forward(self, src, trg):
-        src_mask = self.make_src_mask(src)
-        trg_mask = self.make_trg_mask(trg)
+        src_mask = self.make_pad_mask(src, src)
+
+        src_trg_mask = self.make_pad_mask(trg, src)
+
+        trg_mask = self.make_pad_mask(trg, trg) * \
+                   self.make_no_peak_mask(trg, trg)
+
         enc_src = self.encoder(src, src_mask)
-        output = self.decoder(trg, enc_src, trg_mask, src_mask)
+        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
         return output
 
-    def make_src_mask(self, src):
-        batch_size, length = src.size()
+    def make_pad_mask(self, q, k):
+        len_q, len_k = q.size(1), k.size(1)
 
         # batch_size x 1 x 1 x len_k
-        src_k = src.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(2)
+        k = k.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(2)
         # batch_size x 1 x len_q x len_k
-        src_k = src_k.repeat(1, 1, length, 1)
+        k = k.repeat(1, 1, len_q, 1)
 
         # batch_size x 1 x len_q x 1
-        src_q = src.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(3)
+        q = q.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(3)
         # batch_size x 1 x len_q x len_k
-        src_q = src_q.repeat(1, 1, 1, length)
+        q = q.repeat(1, 1, 1, len_k)
 
-        src_mask = src_k & src_q
-        return src_mask
+        mask = k & q
+        return mask
 
-    def make_trg_mask(self, trg):
-        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
-        trg_len = trg.shape[1]
-        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
-        trg_mask = trg_pad_mask & trg_sub_mask
-        return trg_mask
+    def make_no_peak_mask(self, q, k):
+        len_q, len_k = q.size(1), k.size(1)
+
+        # len_q x len_k
+        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor).to(self.device)
+
+        return mask
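The net effect of the rewrite: padding masks are now built pairwise over a (query, key) pair of sequences, so the decoder's cross-attention gets a trg-by-src mask (src_trg_mask) instead of reusing the src-by-src one, and the causal constraint is factored out into make_no_peak_mask. A standalone sketch of the same construction on a toy batch (pad_idx=1 and the example tensors are assumptions):

import torch

pad_idx = 1

def make_pad_mask(q, k):
    # True where both the query and the key position hold real tokens
    len_q, len_k = q.size(1), k.size(1)
    k_mask = k.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, len_q, 1)
    q_mask = q.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, len_k)
    return k_mask & q_mask                         # batch x 1 x len_q x len_k

def make_no_peak_mask(q, k):
    # lower triangle: position i may attend to positions <= i
    return torch.tril(torch.ones(q.size(1), k.size(1))).bool()

src = torch.tensor([[4, 5, 6, 1, 1]])              # length 5, two pads
trg = torch.tensor([[7, 8, 1]])                    # length 3, one pad

src_mask = make_pad_mask(src, src)                 # 1 x 1 x 5 x 5
src_trg_mask = make_pad_mask(trg, src)             # 1 x 1 x 3 x 5 (cross-attention)
trg_mask = make_pad_mask(trg, trg) & make_no_peak_mask(trg, trg)  # 1 x 1 x 3 x 3

Combining the two decoder masks with & (the commit uses *, which acts as logical AND on bool tensors) leaves True exactly where a real target token attends to a real, non-future position.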
1 change: 0 additions & 1 deletion train.py
@@ -47,7 +47,6 @@ def initialize_weights(m):
 scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                  verbose=True,
                                                  factor=factor,
-                                                 min_lr=min_lr,
                                                  patience=patience)
 
 criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
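For reference, ReduceLROnPlateau still accepts min_lr; dropping the argument just leaves the default floor of 0. A minimal sketch of how the scheduler is driven (the model, optimizer, and loss below are placeholders):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)                            # placeholder model
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 factor=0.9,
                                                 patience=10)

for epoch in range(3):
    val_loss = torch.rand(1).item()                # stand-in for real validation
    scheduler.step(val_loss)                       # LR shrinks when the loss plateaus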
Binary file added util/__pycache__/__init__.cpython-37.pyc
Binary file added util/__pycache__/bleu.cpython-37.pyc
Binary file added util/__pycache__/data_loader.cpython-37.pyc
Binary file added util/__pycache__/epoch_timer.cpython-37.pyc
Binary file added util/__pycache__/tokenizer.cpython-37.pyc
4 changes: 2 additions & 2 deletions util/data_loader.py
@@ -3,8 +3,8 @@
 @when : 2019-10-29
 @homepage : https://github.com/gusdnd852
 """
-from torchtext.data import Field, BucketIterator
-from torchtext.datasets.translation import Multi30k
+from torchtext.legacy.data import Field, BucketIterator
+from torchtext.legacy.datasets.translation import Multi30k
 
 
 class DataLoader:
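The import change tracks torchtext's reorganization: from torchtext 0.9 the old Field/BucketIterator API lives under torchtext.legacy (it was removed entirely in 0.12). A version-tolerant variant, if both layouts need to be supported (a sketch, not the repository's code):

try:
    from torchtext.legacy.data import Field, BucketIterator
    from torchtext.legacy.datasets.translation import Multi30k
except ImportError:  # torchtext < 0.9 keeps the old paths
    from torchtext.data import Field, BucketIterator
    from torchtext.datasets.translation import Multi30k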
4 changes: 2 additions & 2 deletions util/tokenizer.py
@@ -9,8 +9,8 @@
 class Tokenizer:
 
     def __init__(self):
-        self.spacy_de = spacy.load('de')
-        self.spacy_en = spacy.load('en')
+        self.spacy_de = spacy.load('de_core_news_sm')
+        self.spacy_en = spacy.load('en_core_web_sm')
 
     def tokenize_de(self, text):
         """
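spaCy v3 removed the 'de'/'en' shortcut links, so models must be installed and loaded under their full package names. The models are fetched once from the command line, after which the loads above succeed:

# one-time downloads:
#   python -m spacy download de_core_news_sm
#   python -m spacy download en_core_web_sm
import spacy

spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')
print([tok.text for tok in spacy_en.tokenizer('A quick test.')])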
