trim vocab(s) before saving checkpoint (OpenNMT#1453)
francoishernandez authored and vince62s committed Jun 3, 2019
1 parent 1b3cc33 commit 8292a0e
1 changed file with 16 additions and 1 deletion: onmt/models/model_saver.py
@@ -108,10 +108,25 @@ def _save(self, step, model):
         model_state_dict = {k: v for k, v in model_state_dict.items()
                             if 'generator' not in k}
         generator_state_dict = real_generator.state_dict()
+
+        # NOTE: We need to trim the vocab to remove any unk tokens that
+        # were not originally here.
+
+        vocab = deepcopy(self.fields)
+        for side in ["src", "tgt"]:
+            keys_to_pop = []
+            if hasattr(vocab[side], "fields"):
+                unk_token = vocab[side].fields[0][1].vocab.itos[0]
+                for key, value in vocab[side].fields[0][1].vocab.stoi.items():
+                    if value == 0 and key != unk_token:
+                        keys_to_pop.append(key)
+                for key in keys_to_pop:
+                    vocab[side].fields[0][1].vocab.stoi.pop(key, None)
+
         checkpoint = {
             'model': model_state_dict,
             'generator': generator_state_dict,
-            'vocab': self.fields,
+            'vocab': vocab,
             'opt': self.model_opt,
             'optim': self.optim.state_dict(),
         }
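
Why the trim is needed: torchtext-style vocabularies (which OpenNMT-py's fields wrap, at least in the torchtext versions contemporary with this commit) build stoi as a defaultdict whose factory returns the <unk> index, 0. Because a defaultdict inserts a missing key the moment it is looked up, every out-of-vocabulary token encountered during training leaves a stray entry in stoi mapping to 0, and those entries would otherwise be serialized into the checkpoint. Below is a minimal standalone sketch of the failure mode and of the same trim applied in isolation; the toy itos/stoi values are illustrative, not taken from the commit.

    from collections import defaultdict
    from copy import deepcopy

    # Toy stoi/itos shaped like torchtext's: index 0 is <unk>, and missing
    # keys fall back to 0 via the defaultdict factory.
    itos = ['<unk>', '<pad>', 'hello']
    stoi = defaultdict(lambda: 0, {'<unk>': 0, '<pad>': 1, 'hello': 2})

    # Merely *reading* an out-of-vocabulary key inserts it, which is how
    # stray entries accumulate during training.
    _ = stoi['never-seen-token']
    assert 'never-seen-token' in stoi

    # The commit's trim, in isolation: drop every key that maps to the
    # unk index except the unk token itself.
    trimmed = deepcopy(stoi)
    unk_token = itos[0]
    keys_to_pop = [k for k, v in trimmed.items() if v == 0 and k != unk_token]
    for k in keys_to_pop:
        trimmed.pop(k, None)

    assert 'never-seen-token' not in trimmed
    assert unk_token in trimmed

Note that the commit applies the trim to a deepcopy of self.fields, so only the serialized vocab is pruned; the live training vocab keeps its defaultdict behavior untouched.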
