Skip to content

Commit 36daef3

Browse files
committed
Update translate.py
1 parent f70c52a commit 36daef3

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/12_transformer/translate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from modules.data_loader import DataLoader
77
import modules.data_loader as data_loader
8-
from modules.seq2seq import Seq2Seq
8+
from modules.transformer import Transformer
99

1010

1111
def define_argparser():
@@ -108,14 +108,15 @@ def get_vocabs(train_config, config, saved_data):
108108

109109

110110
def get_model(input_size, output_size, train_config):
111-
model = Seq2Seq(
112-
input_size,
113-
train_config.word_vec_size,
114-
train_config.hidden_size,
115-
output_size,
116-
n_layers=train_config.n_layers,
117-
dropout_p=train_config.dropout,
118-
)
111+
model = Transformer(
112+
input_size,
113+
train_config.hidden_size,
114+
output_size,
115+
n_splits=train_config.n_splits,
116+
n_enc_blocks=train_config.n_layers,
117+
n_dec_blocks=train_config.n_layers,
118+
dropout_p=train_config.dropout,
119+
)
119120
model.load_state_dict(saved_data['model'])
120121
model.eval()
121122
return model

0 commit comments

Comments
 (0)