fix onmt as library example #1292

Merged
merged 8 commits into OpenNMT:master on Feb 15, 2019

Conversation

elisemicho
Contributor

This example is great for digging into how things work in OpenNMT, but there have been quite a lot of recent changes in datasets, vocabs, losses, optimizers, iterators, trainer, translator, and logging, so the provided example no longer worked.
I made it work again for training, but I am missing something for translating, as only unks are output.
Can you help correct this last part and check that the modifications make correct use of the functions?
Thanks!

fixing some pep8 stuff
Contributor

@pltrdy left a comment

Thanks for the PR,
I just suggested a bunch of pep8 fixes to make it a bit more compliant.
As I have no direct write access to this PR, I opened a PR on your branch; please merge it into your repo before we merge this one into master.

elisemicho#1

a bit more pep8 compliant version
@elisemicho
Contributor Author

Thanks for the formatting!

@vince62s
Member

Thanks for your contribution.
@elisemicho can you please adjust based on #1296?

Do you think we could have a more meaningful toy set / example that does not output only unk?
If not, that's ok, but for academic purposes that would be great.

@vince62s mentioned this pull request Feb 14, 2019
@elisemicho
Contributor Author

Thank you for your comments. I also think it would be better to output a real translation, and we should be able to get one from this data and this model, as was the case before the code changes.
I probably just misuse one of the objects in the translation part. Can you check?

from itertools import chain
train_data_file = "data/data.train.0.pt"
valid_data_file = "data/data.valid.0.pt"
dataset_fields = dict(chain.from_iterable(vocab_fields.values()))
Contributor Author

It is not very intuitive that iterators expect "fields" to be a dict[str, torchtext.data.Field] whereas Translator or TranslationBuilder expect "fields" to be a dict[str, list[tuple[str, torchtext.data.Field]]].
Maybe choose distinct names?
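
To make the mismatch concrete, here is a rough sketch of the two shapes (the Field construction is illustrative, not what preprocess.py actually produces):

from itertools import chain
from torchtext.data import Field

# Nested shape, as Translator / TranslationBuilder expect it:
nested_fields = {"src": [("src", Field())],
                 "tgt": [("tgt", Field())]}

# Flat shape, as DatasetLazyIter actually consumes it:
flat_fields = dict(chain.from_iterable(nested_fields.values()))
# -> {"src": <Field>, "tgt": <Field>}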

Contributor

Or they should all expect the same structure. See this comment

https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/train_single.py#L139

The docs for the format that DatasetLazyIter expects are evidently wrong.

Contributor Author

Should they? As of now, giving DatasetLazyIter a dict[str, torchtext.data.Field] actually makes the code work, while giving it the original dict[str, list[tuple[str, torchtext.data.Field]]] does not allow training to start because the structure is not correct. Could you expand a bit more on the logic behind this?

Contributor

You're correct that DatasetLazyIter actually needs dict[str, torchtext.data.Field]. The docs claim DatasetLazyIter expects dict[str, list[tuple[str, torchtext.data.Field]]] (I wrote those docs and I got it wrong because I incorrectly assumed it would take the same structure as the other field arguments, sorry!).

Now, the comment in train_single.py that I linked above says that DatasetLazyIter expecting a different structure is (or, should have been) a temporary thing. It's a trivial fix. Move the dict(itertools.chain.from_iterable(fields.values())) business inside DatasetLazyIter.

But it gets weird, because when torchtext says fields, e.g. in Dataset, it's talking about list[tuple[str, Field]]. onmt uses the extra level of nesting because onmt text features (POS tags, NER, etc.) used to be handled as separate fields. Now that's unnecessary because of TextMultiField.

All to say: yes, they should be named differently and the docs should be updated, OR they should expect the same thing (either with the extra level of nesting or, more elegantly, without it, so that it matches torchtext).
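
A minimal sketch of that "trivial fix", with all other constructor arguments elided (this is hypothetical, not the actual onmt signature):

from itertools import chain

class DatasetLazyIter:
    # Sketch only: the real constructor also takes batch size, device,
    # is_train, etc.
    def __init__(self, dataset_paths, fields):
        self.dataset_paths = dataset_paths
        # Accept the same nested dict[str, list[tuple[str, Field]]] as
        # Translator/TranslationBuilder, and flatten it internally to the
        # dict[str, Field] that the torchtext iterator machinery needs.
        self.fields = dict(chain.from_iterable(fields.values()))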

Contributor

Well, torchtext is inconsistent about what it means too:
https://github.com/pytorch/text/blob/master/torchtext/data/dataset.py#L23 (fields is a dict[str, Field])
https://github.com/pytorch/text/blob/master/torchtext/data/example.py#L19 (fields should be dict[str, list[tuple[str, Field]]] or dict[str, tuple[str, Field]])

Contributor Author

Interesting. I have left it that way in the example for now, just making things a bit more explicit; let me know when it is settled how onmt handles this.

@flauted
Contributor

commented Feb 14, 2019

@elisemicho I think I got it. The root of the problem is preprocessing with 1000 tokens in the input and output vocabs. Bump that to 10000:

python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000

That alone will fix the <unk>-iness, but the predictions will mostly be empty or periods.

If you want a more realistic example like @vince62s is saying, try this:

emb_size = 100
rnn_size = 500
(train_iter) batch_size = 50
(train_iter) repeat = True
train_steps = 400
valid_steps = 200
(ReportMgr) report_every = 50

Here's my code:

from itertools import chain

import torch
import torch.nn as nn

import onmt
import onmt.inputters
import onmt.modules
import onmt.utils
import onmt.translate
from onmt.utils.logging import init_logger


init_logger(None)


def main():
    vocab_fields = torch.load("data/data.vocab.pt")

    src_text_field = vocab_fields["src"][0][1].base_field
    src_vocab = src_text_field.vocab
    src_padding = src_vocab.stoi[src_text_field.pad_token]

    tgt_text_field = vocab_fields['tgt'][0][1].base_field
    tgt_vocab = tgt_text_field.vocab
    tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

    emb_size = 100
    rnn_size = 500
    # Specify the core model.
    encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),
                                                 word_padding_idx=src_padding)

    encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                       rnn_type="LSTM", bidirectional=True,
                                       embeddings=encoder_embeddings)

    decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),
                                                 word_padding_idx=tgt_padding)
    decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
        hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True,
        rnn_type="LSTM", embeddings=decoder_embeddings)

    model = onmt.models.model.NMTModel(encoder, decoder)

    # you could use torch.device(str) instead of str if you want
    dev = "cuda" if torch.cuda.is_available() else "cpu"

    # Specify the tgt word generator and loss computation module
    model.generator = nn.Sequential(
        nn.Linear(rnn_size, len(tgt_vocab)),
        nn.LogSoftmax(dim=-1))

    model.to(dev)

    loss = onmt.utils.loss.NMTLossCompute(
        criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction='sum'),
        generator=model.generator)

    lr = 1
    torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    optim = onmt.utils.optimizers.Optimizer(
        torch_optimizer, learning_rate=lr, max_grad_norm=2)

    # Load some data
    train_data_file = "data/data.train.0.pt"
    valid_data_file = "data/data.valid.0.pt"
    dataset_fields = dict(chain.from_iterable(vocab_fields.values()))
    train_iter = onmt.inputters.inputter.DatasetLazyIter(
        dataset_paths=[train_data_file],
        fields=dataset_fields,
        batch_size=50,
        batch_size_multiple=1,
        batch_size_fn=None,
        device=dev,
        is_train=True,
        repeat=True)

    valid_iter = onmt.inputters.inputter.DatasetLazyIter(
        dataset_paths=[valid_data_file],
        fields=dataset_fields,
        batch_size=10,
        batch_size_multiple=1,
        batch_size_fn=None,
        device=dev,
        is_train=False,
        repeat=False)

    report_mgr = onmt.utils.ReportMgr(
        report_every=50, start_time=None, tensorboard_writer=None)
    trainer = onmt.Trainer(model=model,
                           train_loss=loss,
                           valid_loss=loss,
                           optim=optim,
                           report_manager=report_mgr)
    trainer.train(train_iter=train_iter,
                  train_steps=400,
                  valid_iter=valid_iter,
                  valid_steps=200)

    # Decode the validation data: build a Translator around the trained
    # model and a TranslationBuilder to map predictions back to tokens.
    src_reader = onmt.inputters.str2reader["text"]
    tgt_reader = onmt.inputters.str2reader["text"]
    scorer = onmt.translate.GNMTGlobalScorer(alpha=0.7,
                                             beta=0,
                                             length_penalty="avg",
                                             coverage_penalty="none")

    xdev = 0 if torch.cuda.is_available() else -1
    translator = onmt.translate.Translator(model=model,
                                           fields=vocab_fields,
                                           src_reader=src_reader,
                                           tgt_reader=tgt_reader,
                                           global_scorer=scorer,
                                           gpu=xdev)
    builder = onmt.translate.TranslationBuilder(
        data=torch.load(valid_data_file),
        fields=vocab_fields)

    for batch in valid_iter:
        trans_batch = translator.translate_batch(
            batch=batch, src_vocabs=[src_vocab],
            attn_debug=False)
        translations = builder.from_batch(trans_batch)
        for trans in translations:
            print(trans.log(0))


if __name__ == "__main__":
    main()

@elisemicho
Contributor Author

Thank you for the improvements; increasing the scale is indeed a good idea. However, I still reproduce the many-unks output with your code. Do you get better output?

@flauted
Contributor

commented Feb 14, 2019

@elisemicho Did you preprocess it with 10000 tokens instead of 1000? That's key here.

@elisemicho
Contributor Author

Yes. Just to make sure, I tried it again; perplexity and prediction scores look better, but unks are still prevalent.

python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000
[2019-02-14 18:25:43,350 INFO] Extracting features...
[2019-02-14 18:25:43,350 INFO]  * number of source features: 0.
[2019-02-14 18:25:43,350 INFO]  * number of target features: 0.
[2019-02-14 18:25:43,350 INFO] Building `Fields` object...
[2019-02-14 18:25:43,350 INFO] Building & saving training data...
[2019-02-14 18:25:43,350 INFO] Reading source and target files: data/src-train.txt data/tgt-train.txt.
[2019-02-14 18:25:43,354 INFO] Building shard 0.
[2019-02-14 18:25:43,614 INFO]  * saving 0th train data shard to data/data.train.0.pt.
[2019-02-14 18:25:44,084 INFO] Building & saving validation data...
[2019-02-14 18:25:44,084 INFO] Reading source and target files: data/src-val.txt data/tgt-val.txt.
[2019-02-14 18:25:44,086 INFO] Building shard 0.
[2019-02-14 18:25:44,176 INFO]  * saving 0th valid data shard to data/data.valid.0.pt.
[2019-02-14 18:25:44,383 INFO] Building & saving vocabulary...
[2019-02-14 18:25:44,496 INFO]  * reloading data/data.train.0.pt.
[2019-02-14 18:25:44,915 INFO]  * tgt vocab size: 10004.
[2019-02-14 18:25:44,975 INFO]  * src vocab size: 10002.

python example.py
[2019-02-14 18:25:57,291 INFO] Start training loop and validate every 200 steps...
[2019-02-14 18:25:57,389 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-14 18:27:20,504 INFO] Step 50/  400; acc:  11.40; ppl: 1573.22; xent: 7.36; lr: 1.00000; 671/671 tok/s;     83 sec
[2019-02-14 18:28:44,065 INFO] Step 100/  400; acc:  13.35; ppl: 592.66; xent: 6.38; lr: 1.00000; 679/668 tok/s;    167 sec
[2019-02-14 18:30:13,601 INFO] Step 150/  400; acc:  14.77; ppl: 447.79; xent: 6.10; lr: 1.00000; 661/657 tok/s;    256 sec
[2019-02-14 18:31:36,674 INFO] Step 200/  400; acc:  17.26; ppl: 345.90; xent: 5.85; lr: 1.00000; 642/645 tok/s;    339 sec
[2019-02-14 18:31:36,789 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-14 18:32:31,256 INFO] Validation perplexity: 210.332
[2019-02-14 18:32:31,256 INFO] Validation accuracy: 23.4257
[2019-02-14 18:32:31,375 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-14 18:33:23,477 INFO] Step 250/  400; acc:  17.86; ppl: 303.87; xent: 5.72; lr: 1.00000; 523/522 tok/s;    446 sec
[2019-02-14 18:34:14,834 INFO] Step 300/  400; acc:  18.68; ppl: 267.40; xent: 5.59; lr: 1.00000; 1105/1087 tok/s;    498 sec
[2019-02-14 18:35:08,973 INFO] Step 350/  400; acc:  19.28; ppl: 236.45; xent: 5.47; lr: 1.00000; 1093/1086 tok/s;    552 sec
[2019-02-14 18:35:59,600 INFO] Step 400/  400; acc:  21.33; ppl: 210.03; xent: 5.35; lr: 1.00000; 1053/1058 tok/s;    602 sec
[2019-02-14 18:35:59,689 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-14 18:36:43,776 INFO] Validation perplexity: 181.166
[2019-02-14 18:36:43,776 INFO] Validation accuracy: 18.3015
[2019-02-14 18:36:43,827 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000

SENT 0: ['Parliament', 'Does', 'Not', 'Support', 'Amendment', 'Freeing', 'Tymoshenko']
PRED 0: <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
PRED SCORE: -1.2510


SENT 0: ['Today', ',', 'the', 'Ukraine', 'parliament', 'dismissed', ',', 'within', 'the', 'Code', 'of', 'Criminal', 'Procedure', 'amendment', ',', 'the', 'motion', 'to', 'revoke', 'an', 'article', 'based', 'on', 'which', 'the', 'opposition', 'leader', ',', 'Yulia', 'Tymoshenko', ',', 'was', 'sentenced', '.']
PRED 0: In der Europäischen Union , die die Mitgliedstaaten der EU , die die Mitgliedstaaten der EU , die die Mitgliedstaaten der EU , die die Mitgliedstaaten der EU , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk>
PRED SCORE: -1.4730


SENT 0: ['The', 'amendment', 'that', 'would', 'lead', 'to', 'freeing', 'the', 'imprisoned', 'former', 'Prime', 'Minister', 'was', 'revoked', 'during', 'second', 'reading', 'of', 'the', 'proposal', 'for', 'mitigation', 'of', 'sentences', 'for', 'economic', 'offences', '.']
PRED 0: Das <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> , die für die
PRED SCORE: -1.8630

@flauted
Contributor

commented Feb 14, 2019

I mean there's stochasticity. Some runs will be better than others.

[2019-02-14 12:46:11,178 INFO] Start training loop and validate every 200 steps...
[2019-02-14 12:46:11,268 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-14 12:46:13,577 INFO] Step 50/  400; acc:  11.44; ppl: 1639.46; xent: 7.40; lr: 1.00000; 24735/24464 tok/s;      2 sec
[2019-02-14 12:46:15,554 INFO] Step 100/  400; acc:  14.10; ppl: 539.04; xent: 6.29; lr: 1.00000; 26306/26549 tok/s;      4 sec
[2019-02-14 12:46:17,278 INFO] Step 150/  400; acc:  15.27; ppl: 433.24; xent: 6.07; lr: 1.00000; 32708/32752 tok/s;      6 sec
[2019-02-14 12:46:19,050 INFO] Step 200/  400; acc:  16.79; ppl: 349.49; xent: 5.86; lr: 1.00000; 32367/31810 tok/s;      8 sec
[2019-02-14 12:46:19,117 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-14 12:46:25,429 INFO] Validation perplexity: 296.3
[2019-02-14 12:46:25,429 INFO] Validation accuracy: 14.98
[2019-02-14 12:46:25,517 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-14 12:46:27,377 INFO] Step 250/  400; acc:  17.22; ppl: 318.21; xent: 5.76; lr: 1.00000; 7128/7050 tok/s;     16 sec
[2019-02-14 12:46:28,966 INFO] Step 300/  400; acc:  19.25; ppl: 253.99; xent: 5.54; lr: 1.00000; 32722/33025 tok/s;     18 sec
[2019-02-14 12:46:30,687 INFO] Step 350/  400; acc:  19.49; ppl: 241.71; xent: 5.49; lr: 1.00000; 32763/32807 tok/s;     20 sec
[2019-02-14 12:46:32,440 INFO] Step 400/  400; acc:  20.14; ppl: 214.79; xent: 5.37; lr: 1.00000; 32725/32162 tok/s;     21 sec
[2019-02-14 12:46:32,503 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-14 12:46:38,781 INFO] Validation perplexity: 160.807
[2019-02-14 12:46:38,782 INFO] Validation accuracy: 21.1167
[2019-02-14 12:46:38,813 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000

SENT 0: ['Parliament', 'Does', 'Not', 'Support', 'Amendment', 'Freeing', 'Tymoshenko']
PRED 0: <unk> .
PRED SCORE: -1.5438


SENT 0: ['Today', ',', 'the', 'Ukraine', 'parliament', 'dismissed', ',', 'within', 'the', 'Code', 'of', 'Criminal', 'Procedure', 'amendment', ',', 'the', 'motion', 'to', 'revoke', 'an', 'article', 'based', 'on', 'which', 'the', 'opposition', 'leader', ',', 'Yulia', 'Tymoshenko', ',', 'was', 'sentenced', '.']
PRED 0: Herr Präsident , die <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk> <unk> , die in der <unk>
PRED SCORE: -1.4523


SENT 0: ['The', 'amendment', 'that', 'would', 'lead', 'to', 'freeing', 'the', 'imprisoned', 'former', 'Prime', 'Minister', 'was', 'revoked', 'during', 'second', 'reading', 'of', 'the', 'proposal', 'for', 'mitigation', 'of', 'sentences', 'for', 'economic', 'offences', '.']
PRED 0: Die <unk> ist die <unk> der <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk> des <unk> für die <unk>
PRED SCORE: -1.5317


SENT 0: ['In', 'October', ',', 'Tymoshenko', 'was', 'sentenced', 'to', 'seven', 'years', 'in', 'prison', 'for', 'entering', 'into', 'what', 'was', 'reported', 'to', 'be', 'a', 'disadvantageous', 'gas', 'deal', 'with', 'Russia', '.']
PRED 0: In der <unk> ist in der <unk> in der <unk> , in der <unk> in der <unk> <unk> , in der <unk> , in der <unk> , in der <unk> und <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk> , in der <unk>
PRED SCORE: -1.7339


SENT 0: ['The', 'verdict', 'is', 'not', 'yet', 'final;', 'the', 'court', 'will', 'hear', 'Tymoshenko', '&apos;s', 'appeal', 'in', 'December', '.']
PRED 0: Die <unk> ist nicht in der <unk> <unk> .
PRED SCORE: -1.0052


SENT 0: ['Tymoshenko', 'claims', 'the', 'verdict', 'is', 'a', 'political', 'revenge', 'of', 'the', 'regime;', 'in', 'the', 'West', ',', 'the', 'trial', 'has', 'also', 'evoked', 'suspicion', 'of', 'being', 'biased', '.']
PRED 0: <unk> ist in der <unk> <unk> in der <unk> <unk> .
PRED SCORE: -1.1670


SENT 0: ['The', 'proposal', 'to', 'remove', 'Article', '365', 'from', 'the', 'Code', 'of', 'Criminal', 'Procedure', ',', 'upon', 'which', 'the', 'former', 'Prime', 'Minister', 'was', 'sentenced', ',', 'was', 'supported', 'by', '147', 'members', 'of', 'parliament', '.']
PRED 0: Die <unk> ist in der <unk> <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk> , die <unk>
PRED SCORE: -1.5980


SENT 0: ['Its', 'ratification', 'would', 'require', '226', 'votes', '.']
PRED 0: Die <unk> .
PRED SCORE: -1.5181

@vince62s
Member

Not sure if you have new changes to make after #1299.
Let me know when it's ready to merge.

@elisemicho
Contributor Author

commented Feb 15, 2019

@vince62s Indeed I had some changes to make; it's done and ready to be merged.
Thank you all, the library example is back!

@vince62s vince62s merged commit 492f0cf into OpenNMT:master Feb 15, 2019
@vince62s
Member

Reverting to understand why the Travis build broke.

vince62s added a commit that referenced this pull request Feb 16, 2019
ItaySofer pushed a commit to ItaySofer/OpenNMT-py that referenced this pull request Mar 17, 2019