fix onmt as library example #1292

Merged 8 commits on Feb 15, 2019

docs/source/Library.md: 257 changes (137 additions, 120 deletions)
For this example, we will assume that we have run `preprocess.py` to
create our datasets. For instance:

> python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000


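Preprocessing writes a vocab file plus sharded train/valid datasets. As a quick check (not part of the original example), you can confirm the expected artifacts exist; the shard names here match what this example loads below:

```python
import glob

# Expected artifacts from the preprocess run above (with -save_data data/data):
#   data/data.vocab.pt     source/target vocabulary fields
#   data/data.train.0.pt   training shard(s)
#   data/data.valid.0.pt   validation shard(s)
print(sorted(glob.glob("data/data*.pt")))
```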

We begin by loading in the vocabulary for the model of interest. This will let us check the vocab size and look up the special padding token indices.


```python
import torch
import onmt

vocab_fields = torch.load("data/data.vocab.pt")

src_text_field = vocab_fields["src"].base_field
src_vocab = src_text_field.vocab
src_padding = src_vocab.stoi[src_text_field.pad_token]

tgt_text_field = vocab_fields["tgt"].base_field
tgt_vocab = tgt_text_field.vocab
tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]
```
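
Not part of the original example, but a quick way to confirm the fields loaded correctly, using the variables defined above:

```python
# Quick sanity checks on the loaded fields (illustrative only).
print(len(src_vocab), len(tgt_vocab))          # vocabulary sizes
print(src_text_field.pad_token, src_padding)   # pad token and its index
print(tgt_vocab.itos[:10])                     # a few target-side tokens
```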

Next we specify the core model itself. Here we will build a small model with an encoder and an attention-based input-feeding decoder. Both will be RNNs, and the encoder will be bidirectional.


```python
import torch.nn as nn

emb_size = 100
rnn_size = 500

# Specify the core model.
encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),
                                             word_padding_idx=src_padding)

encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                   rnn_type="LSTM", bidirectional=True,
                                   embeddings=encoder_embeddings)

decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),
                                             word_padding_idx=tgt_padding)
decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
    hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True,
    rnn_type="LSTM", embeddings=decoder_embeddings)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = onmt.models.model.NMTModel(encoder, decoder)
model.to(device)

# Specify the tgt word generator and loss computation module.
# Note the .to(device): the generator is created after model.to(device),
# so it must be moved to the same device explicitly.
model.generator = nn.Sequential(
    nn.Linear(rnn_size, len(tgt_vocab)),
    nn.LogSoftmax(dim=-1)).to(device)

loss = onmt.utils.loss.NMTLossCompute(
    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),
    generator=model.generator)
```
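
As a lightweight check that the pieces fit together (illustrative, not from the original example), you can count the model's trainable parameters; the generator is registered as a submodule, so it is included:

```python
# Count trainable parameters as a rough structural check (illustrative).
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("trainable parameters:", n_params)
```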

Now we set up the optimizer. Our wrapper around a core torch optim class handles learning rate updates and gradient normalization automatically.


```python
lr = 1
torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer, learning_rate=lr, max_grad_norm=2)
```
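
For intuition, the `max_grad_norm=2` argument requests standard norm-based gradient clipping. A plain-PyTorch equivalent of that step (which, as an assumption here, the wrapper applies for you during training) would be:

```python
# Roughly what max_grad_norm=2 asks for, in plain PyTorch terms
# (illustrative; the onmt Optimizer wrapper handles this itself):
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2)
```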

Now we load the data from disk with the associated vocab fields. To iterate through the data itself we use a wrapper around a torchtext iterator class. We specify one for both the training and validation data.


```python
train_data_file = "data/data.train.0.pt"
valid_data_file = "data/data.valid.0.pt"
train_iter = onmt.inputters.inputter.DatasetLazyIter(dataset_paths=[train_data_file],
                                                     fields=vocab_fields,
                                                     batch_size=50,
                                                     batch_size_multiple=1,
                                                     batch_size_fn=None,
                                                     device=device,
                                                     is_train=True,
                                                     repeat=True)

valid_iter = onmt.inputters.inputter.DatasetLazyIter(dataset_paths=[valid_data_file],
                                                     fields=vocab_fields,
                                                     batch_size=10,
                                                     batch_size_multiple=1,
                                                     batch_size_fn=None,
                                                     device=device,
                                                     is_train=False,
                                                     repeat=False)
```
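
If you want to see what these iterators yield, you can peek at a single batch. The layout below (a `(tensor, lengths)` pair for `src`, time-major shapes) is an assumption about how this OpenNMT-py version's text fields batch their data:

```python
# Peek at one training batch (illustrative only).
batch = next(iter(train_iter))
src, src_lengths = batch.src       # assumes the src field tracks lengths
print(src.shape, batch.tgt.shape)  # assumed time-major: (seq_len, batch, 1)
```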

Finally we train. Keeping track of the output requires a report manager.


```python
report_manager = onmt.utils.ReportMgr(
    report_every=50, start_time=None, tensorboard_writer=None)

trainer = onmt.Trainer(model=model,
                       train_loss=loss,
                       valid_loss=loss,
                       optim=optim,
                       report_manager=report_manager)

trainer.train(train_iter=train_iter,
              train_steps=400,
              valid_iter=valid_iter,
              valid_steps=200)
```

```
[2019-02-15 16:34:17,475 INFO] Start training loop and validate every 200 steps...
[2019-02-15 16:34:17,601 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-15 16:35:43,873 INFO] Step 50/ 400; acc: 11.54; ppl: 1714.07; xent: 7.45; lr: 1.00000; 662/656 tok/s; 86 sec
[2019-02-15 16:37:05,965 INFO] Step 100/ 400; acc: 13.75; ppl: 534.80; xent: 6.28; lr: 1.00000; 675/671 tok/s; 168 sec
[2019-02-15 16:38:31,289 INFO] Step 150/ 400; acc: 15.02; ppl: 439.96; xent: 6.09; lr: 1.00000; 675/668 tok/s; 254 sec
[2019-02-15 16:39:56,715 INFO] Step 200/ 400; acc: 16.08; ppl: 357.62; xent: 5.88; lr: 1.00000; 642/647 tok/s; 339 sec
[2019-02-15 16:39:56,811 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-15 16:41:13,415 INFO] Validation perplexity: 208.73
[2019-02-15 16:41:13,415 INFO] Validation accuracy: 23.3507
[2019-02-15 16:41:13,567 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-15 16:42:41,562 INFO] Step 250/ 400; acc: 17.07; ppl: 310.41; xent: 5.74; lr: 1.00000; 347/344 tok/s; 504 sec
[2019-02-15 16:44:04,899 INFO] Step 300/ 400; acc: 19.17; ppl: 262.81; xent: 5.57; lr: 1.00000; 665/661 tok/s; 587 sec
[2019-02-15 16:45:33,653 INFO] Step 350/ 400; acc: 19.38; ppl: 244.81; xent: 5.50; lr: 1.00000; 649/642 tok/s; 676 sec
[2019-02-15 16:47:06,141 INFO] Step 400/ 400; acc: 20.44; ppl: 214.75; xent: 5.37; lr: 1.00000; 593/598 tok/s; 769 sec
[2019-02-15 16:47:06,265 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-15 16:48:27,328 INFO] Validation perplexity: 150.277
[2019-02-15 16:48:27,328 INFO] Validation accuracy: 24.2132
```
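
One way to read these rows: the reported perplexity is just the exponential of the per-token cross-entropy (`xent`) in the same row. For example:

```python
import math

# Step 200 in the log: xent 5.88 -> exp(5.88) ~= 357.8,
# matching the reported ppl of 357.62 up to rounding.
print(math.exp(5.88))
```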


To use the model, we need to load up the translation functions. A Translator object requires the vocab fields, readers for the source and target, and a global scorer.


```python
import onmt.translate

src_reader = onmt.inputters.str2reader["text"]
tgt_reader = onmt.inputters.str2reader["text"]
scorer = onmt.translate.GNMTGlobalScorer(alpha=0.7,
                                         beta=0.,
                                         length_penalty="avg",
                                         coverage_penalty="none")
gpu = 0 if torch.cuda.is_available() else -1
translator = onmt.translate.Translator(model=model,
                                       fields=vocab_fields,
                                       src_reader=src_reader,
                                       tgt_reader=tgt_reader,
                                       global_scorer=scorer,
                                       gpu=gpu)
builder = onmt.translate.TranslationBuilder(data=torch.load(valid_data_file),
                                            fields=vocab_fields)

for batch in valid_iter:
    trans_batch = translator.translate_batch(
        batch=batch, src_vocabs=[src_vocab],
        attn_debug=False)
    translations = builder.from_batch(trans_batch)
    for trans in translations:
        print(trans.log(0))
    break
```
```
[2019-02-15 16:48:27,419 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000

SENT 0: ['Parliament', 'Does', 'Not', 'Support', 'Amendment', 'Freeing', 'Tymoshenko']
PRED 0: <unk> ist ein <unk> <unk> <unk> .
PRED SCORE: -1.0983

SENT 0: ['Today', ',', 'the', 'Ukraine', 'parliament', 'dismissed', ',', 'within', 'the', 'Code', 'of', 'Criminal', 'Procedure', 'amendment', ',', 'the', 'motion', 'to', 'revoke', 'an', 'article', 'based', 'on', 'which', 'the', 'opposition', 'leader', ',', 'Yulia', 'Tymoshenko', ',', 'was', 'sentenced', '.']
PRED 0: <unk> ist das <unk> <unk> .
PRED SCORE: -1.5950

SENT 0: ['The', 'amendment', 'that', 'would', 'lead', 'to', 'freeing', 'the', 'imprisoned', 'former', 'Prime', 'Minister', 'was', 'revoked', 'during', 'second', 'reading', 'of', 'the', 'proposal', 'for', 'mitigation', 'of', 'sentences', 'for', 'economic', 'offences', '.']
PRED 0: Es gibt es das <unk> der <unk> für <unk> <unk> .
PRED SCORE: -1.5128

SENT 0: ['In', 'October', ',', 'Tymoshenko', 'was', 'sentenced', 'to', 'seven', 'years', 'in', 'prison', 'for', 'entering', 'into', 'what', 'was', 'reported', 'to', 'be', 'a', 'disadvantageous', 'gas', 'deal', 'with', 'Russia', '.']
PRED 0: <unk> ist ein <unk> <unk> .
PRED SCORE: -1.5578

SENT 0: ['The', 'verdict', 'is', 'not', 'yet', 'final;', 'the', 'court', 'will', 'hear', 'Tymoshenko', '&apos;s', 'appeal', 'in', 'December', '.']
PRED 0: <unk> ist nicht <unk> .
PRED SCORE: -0.9623

SENT 0: ['Tymoshenko', 'claims', 'the', 'verdict', 'is', 'a', 'political', 'revenge', 'of', 'the', 'regime;', 'in', 'the', 'West', ',', 'the', 'trial', 'has', 'also', 'evoked', 'suspicion', 'of', 'being', 'biased', '.']
PRED 0: <unk> ist ein <unk> <unk> .
PRED SCORE: -0.8703

SENT 0: ['The', 'proposal', 'to', 'remove', 'Article', '365', 'from', 'the', 'Code', 'of', 'Criminal', 'Procedure', ',', 'upon', 'which', 'the', 'former', 'Prime', 'Minister', 'was', 'sentenced', ',', 'was', 'supported', 'by', '147', 'members', 'of', 'parliament', '.']
PRED 0: <unk> Sie sich mit <unk> <unk> .
PRED SCORE: -1.4778

SENT 0: ['Its', 'ratification', 'would', 'require', '226', 'votes', '.']
PRED 0: <unk> Sie sich <unk> .
PRED SCORE: -1.3341

SENT 0: ['Libya', '&apos;s', 'Victory']
PRED 0: <unk> Sie die <unk> <unk> .
PRED SCORE: -1.5192


SENT 0: ['The', 'story', 'of', 'Libya', '&apos;s', 'liberation', ',', 'or', 'rebellion', ',', 'already', 'has', 'its', 'defeated', '.']
PRED 0: <unk> ist ein <unk> <unk> .
PRED SCORE: -1.2772

...