Skip to content

Commit

Permalink
cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
pltrdy committed May 25, 2018
1 parent 1953d39 commit 662a84e
Show file tree
Hide file tree
Showing 44 changed files with 177 additions and 812 deletions.
2 changes: 2 additions & 0 deletions onmt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
""" Main entry point of the ONMT library """
from __future__ import division, print_function

import onmt.inputters
import onmt.encoders
import onmt.decoders
Expand Down
12 changes: 6 additions & 6 deletions onmt/decoders/cnn_decoder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Implementation of the CNN Decoder part of "Convolutional Sequence to Sequence Learning"
Implementation of the CNN Decoder part of
"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn
from torch.autograd import Variable

import onmt.modules
from onmt.decoders.decoder import DecoderState
Expand Down Expand Up @@ -92,8 +92,8 @@ def forward(self, tgt, memory_bank, state, memory_lengths=None):
x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
x = shape_transform(x)

pad = Variable(torch.zeros(x.size(0), x.size(1),
self.cnn_kernel_width - 1, 1))
pad = torch.zeros(x.size(0), x.size(1),
self.cnn_kernel_width - 1, 1)
pad = pad.type_as(x)
base_target_emb = x

Expand Down Expand Up @@ -152,5 +152,5 @@ def update_state(self, new_input):

def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
self.init_src = Variable(
self.init_src.data.repeat(1, beam_size, 1), volatile=True)
self.init_src = torch.tensor(
self.init_src.data.repeat(1, beam_size, 1), requires_grad=False)
4 changes: 2 additions & 2 deletions onmt/decoders/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def _fix_enc_hidden(hidden):
# The encoder hidden is (layers*directions) x batch x dim.
# We need to convert it to layers x batch x (directions*dim).
if self.bidirectional_encoder:
hidden = torch.cat(
[hidden[0:hidden.size(0):2], hidden[1:hidden.size(0):2]], 2)
hidden = torch.cat([hidden[0:hidden.size(0):2],
hidden[1:hidden.size(0):2]], 2)
return hidden

if isinstance(encoder_final, tuple): # LSTM
Expand Down
5 changes: 2 additions & 3 deletions onmt/decoders/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

import onmt
Expand Down Expand Up @@ -57,8 +56,8 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,

src_batch, _, s_len = src_pad_mask.size()
tgt_batch, _, _ = tgt_pad_mask.size()
#src_batch, t_len, s_len = src_pad_mask.size()
#tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
# src_batch, t_len, s_len = src_pad_mask.size()
# tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
aeq(input_batch, contxt_batch, src_batch, tgt_batch)
# aeq(t_len, t_len_, t_len__, input_len)
aeq(s_len, contxt_len)
Expand Down
3 changes: 3 additions & 0 deletions onmt/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder

__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder",
"MeanEncoder"]
2 changes: 1 addition & 1 deletion onmt/encoders/cnn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def forward(self, input, lengths=None, hidden=None):
self._check_args(input, lengths, hidden)

emb = self.embeddings(input)
#s_len, batch, emb_dim = emb.size()
# s_len, batch, emb_dim = emb.size()

emb = emb.transpose(0, 1).contiguous()
emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
Expand Down
2 changes: 1 addition & 1 deletion onmt/encoders/rnn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def forward(self, src, lengths=None):
self._check_args(src, lengths)

emb = self.embeddings(src)
#s_len, batch, emb_dim = emb.size()
# s_len, batch, emb_dim = emb.size()

packed_emb = emb
if lengths is not None and not self.no_pack_padded_seq:
Expand Down
2 changes: 1 addition & 1 deletion onmt/inputters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
collect_features, get_num_features, \
load_fields_from_vocab, get_fields, \
save_fields_to_vocab, build_dataset, \
build_vocab, merge_vocabs, OrderedIterator
build_vocab, merge_vocabs, OrderedIterator
from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, \
EOS_WORD, UNK
from onmt.inputters.text_dataset import TextDataset, ShardedTextCorpusIterator
Expand Down
10 changes: 5 additions & 5 deletions onmt/inputters/audio_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
""" Dataset for data_type=='audio'"""

"""
AudioDataset
"""
import codecs
import os

import torch
import torchtext

from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, EOS_WORD
from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, \
EOS_WORD


class AudioDataset(DatasetBase):
Expand Down Expand Up @@ -140,7 +141,6 @@ def read_audio_file(path, src_dir, side, sample_rate, window_size,
assert (src_dir is not None) and os.path.exists(src_dir),\
"src_dir must be a valid directory if data_type is audio"

#global torchaudio, librosa, np
import torchaudio
import librosa
import numpy as np
Expand Down
10 changes: 5 additions & 5 deletions onmt/inputters/dataset_base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""Define inputters DatasetBase Class."""


"""
Base dataset class and constants
"""
from itertools import chain
import torch
import torchtext

import onmt

PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
UNK = 0
Expand Down
20 changes: 10 additions & 10 deletions onmt/inputters/image_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
""" Dataset for data_type=='img'"""

"""
ImageDataset
"""
import codecs
import os

import torch
import torchtext

from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, EOS_WORD
from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, \
EOS_WORD


class ImageDataset(DatasetBase):
Expand Down Expand Up @@ -95,13 +96,13 @@ def make_image_examples_nfeats_tpl(img_iter, img_path, img_dir):
else:
raise ValueError("""One of 'img_iter' and 'img_path'
must be not None""")
examples_iter = ImageDataset.make_examples(data_iter, img_dir, 'src')
examples_iter = ImageDataset.make_examples(img_iter, img_dir, 'src')
num_feats = 0 # Source side(img) has no features.

return (examples_iter, num_feats)

@staticmethod
def make_examples(data_iter, src_dir, side, truncate=None):
def make_examples(img_iter, src_dir, side, truncate=None):
"""
Args:
path (str): location of a src file containing image paths
Expand All @@ -115,10 +116,6 @@ def make_examples(data_iter, src_dir, side, truncate=None):
assert (src_dir is not None) and os.path.exists(src_dir),\
'src_dir must be a valid directory if data_type is img'

#global Image, transforms
from PIL import Image
from torchvision import transforms

for index, (img, filename) in enumerate(img_iter):
if truncate and truncate != (0, 0):
if not (img.size(1) <= truncate[0]
Expand All @@ -141,6 +138,9 @@ def make_img_iterator_from_file(path, src_dir):
img: an image tensor
filename(str): the image filename
"""
from PIL import Image
from torchvision import transforms

with codecs.open(path, "r", "utf-8") as corpus_file:
for line in corpus_file:
filename = line.strip()
Expand Down
13 changes: 6 additions & 7 deletions onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
"""Define generic inputters."""

from __future__ import print_function
"""
Defining general functions for inputters
"""
import glob
import os
from collections import Counter, defaultdict, OrderedDict
Expand Down Expand Up @@ -177,8 +176,8 @@ def collect_feature_vocabs(fields, side):
return feature_vocabs


def build_dataset(fields, data_type, src_data_iter=None, src_path=None, src_dir=None,
tgt_data_iter=None, tgt_path=None,
def build_dataset(fields, data_type, src_data_iter=None, src_path=None,
src_dir=None, tgt_data_iter=None, tgt_path=None,
src_seq_length=0, tgt_seq_length=0,
src_seq_length_trunc=0, tgt_seq_length_trunc=0,
dynamic_dict=True, sample_rate=0,
Expand Down Expand Up @@ -522,7 +521,7 @@ def _lazy_dataset_loader(pt_file, corpus_type):

def _load_fields(dataset, data_type, opt, checkpoint):
if checkpoint is not None:
print('Loading vocab from checkpoint at %s.' % OPT.train_from)
print('Loading vocab from checkpoint at %s.' % opt.train_from)
fields = load_fields_from_vocab(
checkpoint['vocab'], data_type)
else:
Expand Down
1 change: 1 addition & 0 deletions onmt/inputters/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from onmt.inputters.dataset_base import (DatasetBase, UNK_WORD,
PAD_WORD, BOS_WORD, EOS_WORD)
from onmt.utils import aeq


class TextDataset(DatasetBase):
Expand Down
4 changes: 4 additions & 0 deletions onmt/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
from onmt.encoders.transformer import TransformerEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.encoders.image_encoder import ImageEncoder

from onmt.decoders.decoder import InputFeedRNNDecoder, StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder
from onmt.decoders.cnn_decoder import CNNDecoder

from onmt.modules import Embeddings, CopyGenerator
from onmt.utils.misc import use_gpu

Expand Down
Loading

0 comments on commit 662a84e

Please sign in to comment.