Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianGehrmann committed Jul 13, 2017
1 parent 49b579e commit 52eac06
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,33 +247,39 @@ def test_46_nmtmodel_forward_coverage(self):
self.ntmmodel_forward(opt)

# NO TRANSFORMER FOR NOW - DOES NOT WORK
# def test_47_nmtmodel_forward_decoder_transformer(self):
# """
# Test to check whether the model forward yields the correct size
# with coverage attention
# """
# opt = copy.deepcopy(self.opt)
# opt.decoder_layer = 'transformer'
# self.ntmmodel_forward(opt)

# def test_47_1_nmtmodel_forward_encoder_transformer(self):
# """
# Test to check whether the model forward yields the correct size
# with coverage attention
# """
# opt = copy.deepcopy(self.opt)
# opt.encoder_layer = 'transformer'
# self.ntmmodel_forward(opt)

# def test_47_2_nmtmodel_forward_both_transformer(self):
# """
# Test to check whether the model forward yields the correct size
# with coverage attention
# """
# opt = copy.deepcopy(self.opt)
# opt.decoder_layer = 'transformer'
# opt.encoder_layer = 'transformer'
# self.ntmmodel_forward(opt)
def test_47_nmtmodel_forward_decoder_transformer(self):
"""
Test to check whether the model forward yields the correct size
with coverage attention
"""
opt = copy.deepcopy(self.opt)
opt.decoder_layer = 'transformer'
opt.word_vec_size = 64
opt.rnn_size = 64
self.ntmmodel_forward(opt)

def test_47_1_nmtmodel_forward_encoder_transformer(self):
"""
Test to check whether the model forward yields the correct size
with coverage attention
"""
opt = copy.deepcopy(self.opt)
opt.encoder_layer = 'transformer'
opt.word_vec_size = 64
opt.rnn_size = 64
self.ntmmodel_forward(opt)

def test_47_2_nmtmodel_forward_both_transformer(self):
"""
Test to check whether the model forward yields the correct size
with coverage attention
"""
opt = copy.deepcopy(self.opt)
opt.decoder_layer = 'transformer'
opt.encoder_layer = 'transformer'
opt.word_vec_size = 64
opt.rnn_size = 64
self.ntmmodel_forward(opt)

def test_48_nmtmodel_forward_no_input_feed(self):
"""
Expand Down

0 comments on commit 52eac06

Please sign in to comment.