This PR is the first step of a code refactoring
pltrdy authored and vince62s committed Jun 20, 2018
1 parent e61589d commit e0c10af
Showing 78 changed files with 6,059 additions and 1,720 deletions.
6 changes: 3 additions & 3 deletions .travis.yml
@@ -42,7 +42,7 @@ script:
   # test speech2text preprocessing
   - python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-train.txt -train_tgt /tmp/speech/tgt-train.txt -valid_src /tmp/speech/src-val.txt -valid_tgt /tmp/speech/tgt-val.txt -save_data /tmp/speech/data && rm -rf /tmp/speech/data*.pt
   # test nmt translation
-  - head data/src-test.txt > /tmp/src-test.txt; python translate.py -model test/test_model.pt -src /tmp/src-test.txt -verbose
+  - head data/src-test.txt > /tmp/src-test.txt; python translate.py -model onmt/tests/test_model.pt -src /tmp/src-test.txt -verbose
   # test im2text translation
   - head /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python translate.py -data_type img -src_dir /tmp/im2text/images -model /tmp/test_model_im2text.pt -src /tmp/im2text/src-val-head.txt -tgt /tmp/im2text/tgt-val-head.txt -verbose -out /tmp/im2text/trans
   # test speech2text translation
@@ -57,9 +57,9 @@ script:
   # test speech2text preprocessing and training
   - head /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-val-head.txt -train_tgt /tmp/speech/tgt-val-head.txt -valid_src /tmp/speech/src-val-head.txt -valid_tgt /tmp/speech/tgt-val-head.txt -save_data /tmp/speech/q; python train.py -model_type audio -data /tmp/speech/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -epochs 1 && rm -rf /tmp/speech/q*.pt
   # test nmt translation
-  - python translate.py -model test/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 10 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
+  - python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 10 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
   # test tool
-  - PYTHONPATH=$PYTHONPATH:. python tools/extract_embeddings.py -model test/test_model.pt
+  - PYTHONPATH=$PYTHONPATH:. python tools/extract_embeddings.py -model onmt/tests/test_model.pt

 env:
   global:
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -37,8 +37,8 @@
 from recommonmark.transform import AutoStructify

 source_parsers = {
-'.md': CommonMarkParser,
-}
+    '.md': CommonMarkParser,
+}

 source_suffix = ['.rst', '.md']
 extensions = ['sphinx.ext.autodoc',
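
For reference, the post-change Markdown wiring in docs/source/conf.py amounts to the sketch below. Only source_parsers and source_suffix are visible in this hunk; the CommonMarkParser import is assumed to come from recommonmark.parser, its usual location.

    # Sketch of the resulting Sphinx Markdown setup (assumption: the
    # CommonMarkParser import lives in recommonmark.parser).
    from recommonmark.parser import CommonMarkParser

    source_parsers = {
        '.md': CommonMarkParser,   # parse *.md sources with recommonmark
    }
    source_suffix = ['.rst', '.md']  # accept both reST and Markdown pages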
147 changes: 0 additions & 147 deletions onmt/Optim.py

This file was deleted.

13 changes: 3 additions & 10 deletions onmt/__init__.py
@@ -1,11 +1,4 @@
-import onmt.io
-import onmt.Models
-import onmt.Loss
-import onmt.translate
-import onmt.opts
-from onmt.Trainer import Trainer, Statistics
-from onmt.Optim import Optim
+""" Main entry point of the ONMT library """
+from onmt.trainer import Trainer

-# For flake8 compatibility
-__all__ = [onmt.Loss, onmt.Models, onmt.opts,
-           Trainer, Optim, Statistics, onmt.io, onmt.translate]
+__all__ = ["Trainer"]
1 change: 1 addition & 0 deletions onmt/decoders/__init__.py
@@ -0,0 +1 @@
"""Module defining decoders."""
113 changes: 24 additions & 89 deletions onmt/modules/Conv2Conv.py → onmt/decoders/cnn_decoder.py
@@ -1,95 +1,18 @@
"""
Implementation of "Convolutional Sequence to Sequence Learning"
Implementation of the CNN Decoder part of
"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

import onmt.modules
from onmt.modules.WeightNorm import WeightNormConv2d
from onmt.Models import EncoderBase
from onmt.Models import DecoderState
from onmt.Utils import aeq

from onmt.decoders.decoder import DecoderState
from onmt.utils.misc import aeq
from onmt.utils.cnn_factory import shape_transform, GatedConv

SCALE_WEIGHT = 0.5 ** 0.5


-def shape_transform(x):
-    """ Tranform the size of the tensors to fit for conv input. """
-    return torch.unsqueeze(torch.transpose(x, 1, 2), 3)
-
-
-class GatedConv(nn.Module):
-    def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
-        super(GatedConv, self).__init__()
-        self.conv = WeightNormConv2d(input_size, 2 * input_size,
-                                     kernel_size=(width, 1), stride=(1, 1),
-                                     padding=(width // 2 * (1 - nopad), 0))
-        init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout))**0.5)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x_var, hidden=None):
-        x_var = self.dropout(x_var)
-        x_var = self.conv(x_var)
-        out, gate = x_var.split(int(x_var.size(1) / 2), 1)
-        out = out * F.sigmoid(gate)
-        return out
-
-
-class StackedCNN(nn.Module):
-    def __init__(self, num_layers, input_size, cnn_kernel_width=3,
-                 dropout=0.2):
-        super(StackedCNN, self).__init__()
-        self.dropout = dropout
-        self.num_layers = num_layers
-        self.layers = nn.ModuleList()
-        for i in range(num_layers):
-            self.layers.append(
-                GatedConv(input_size, cnn_kernel_width, dropout))
-
-    def forward(self, x, hidden=None):
-        for conv in self.layers:
-            x = x + conv(x)
-            x *= SCALE_WEIGHT
-        return x
-
-
-class CNNEncoder(EncoderBase):
-    """
-    Encoder built on CNN based on
-    :cite:`DBLP:journals/corr/GehringAGYD17`.
-    """
-
-    def __init__(self, num_layers, hidden_size,
-                 cnn_kernel_width, dropout, embeddings):
-        super(CNNEncoder, self).__init__()
-
-        self.embeddings = embeddings
-        input_size = embeddings.embedding_size
-        self.linear = nn.Linear(input_size, hidden_size)
-        self.cnn = StackedCNN(num_layers, hidden_size,
-                              cnn_kernel_width, dropout)
-
-    def forward(self, input, lengths=None, hidden=None):
-        """ See :obj:`onmt.modules.EncoderBase.forward()`"""
-        self._check_args(input, lengths, hidden)
-
-        emb = self.embeddings(input)
-        s_len, batch, emb_dim = emb.size()
-
-        emb = emb.transpose(0, 1).contiguous()
-        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
-        emb_remap = self.linear(emb_reshape)
-        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
-        emb_remap = shape_transform(emb_remap)
-        out = self.cnn(emb_remap)
-
-        return emb_remap.squeeze(3).transpose(0, 1).contiguous(),\
-            out.squeeze(3).transpose(0, 1).contiguous()


 class CNNDecoder(nn.Module):
     """
     Decoder built on CNN, based on :cite:`DBLP:journals/corr/GehringAGYD17`.
@@ -114,13 +37,13 @@ def __init__(self, num_layers, hidden_size, attn_type,
         input_size = self.embeddings.embedding_size
         self.linear = nn.Linear(input_size, self.hidden_size)
         self.conv_layers = nn.ModuleList()
-        for i in range(self.num_layers):
+        for _ in range(self.num_layers):
             self.conv_layers.append(
                 GatedConv(self.hidden_size, self.cnn_kernel_width,
                           self.dropout, True))

         self.attn_layers = nn.ModuleList()
-        for i in range(self.num_layers):
+        for _ in range(self.num_layers):
             self.attn_layers.append(
                 onmt.modules.ConvMultiStepAttention(self.hidden_size))
@@ -134,10 +57,12 @@ def __init__(self, num_layers, hidden_size, attn_type,

     def forward(self, tgt, memory_bank, state, memory_lengths=None):
         """ See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
+        # NOTE: memory_lengths is only here for compatibility reasons
+        #       with onmt.modules.RNNDecoderBase.forward()
         # CHECKS
         assert isinstance(state, CNNDecoderState)
-        tgt_len, tgt_batch, _ = tgt.size()
-        contxt_len, contxt_batch, _ = memory_bank.size()
+        _, tgt_batch, _ = tgt.size()
+        _, contxt_batch, _ = memory_bank.size()
         aeq(tgt_batch, contxt_batch)
         # END CHECKS
@@ -194,11 +119,18 @@ def forward(self, tgt, memory_bank, state, memory_lengths=None):

         return outputs, state, attns

-    def init_decoder_state(self, src, memory_bank, enc_hidden):
+    def init_decoder_state(self, _, memory_bank, enc_hidden):
+        """
+        Init decoder state.
+        """
         return CNNDecoderState(memory_bank, enc_hidden)


 class CNNDecoderState(DecoderState):
+    """
+    Init CNN decoder state.
+    """
+
     def __init__(self, memory_bank, enc_hidden):
         self.init_src = (memory_bank + enc_hidden) * SCALE_WEIGHT
         self.previous_input = None
@@ -210,9 +142,12 @@ def _all(self):
"""
return (self.previous_input,)

def update_state(self, input):
def detach(self):
self.previous_input = self.previous_input.detach()

def update_state(self, new_input):
""" Called for every decoder forward pass. """
self.previous_input = input
self.previous_input = new_input

def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
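
The new CNNDecoderState.detach() truncates the autograd history held in previous_input, so gradients cannot flow backwards across decoding steps. A hedged usage sketch of the refactored API; decoder, tgt, memory_bank and enc_hidden are hypothetical stand-ins, and the first argument of init_decoder_state is ignored per the new signature:

    # Hypothetical step loop around the refactored CNNDecoder.
    state = decoder.init_decoder_state(None, memory_bank, enc_hidden)
    outputs, state, attns = decoder(tgt, memory_bank, state)
    state.detach()  # drop old autograd history before the next step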
