Add from_opt constructor to all encs/decs. (OpenNMT#1230)
* Add from_opt constructor to all encs/decs.
flauted authored and vince62s committed Jan 29, 2019
1 parent 4e08ac4 commit 45857f7
Showing 13 changed files with 161 additions and 113 deletions.
11 changes: 11 additions & 0 deletions onmt/decoders/__init__.py
@@ -1 +1,12 @@
"""Module defining decoders."""
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \
    StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder
from onmt.decoders.cnn_decoder import CNNDecoder


str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder,
"cnn": CNNDecoder, "transformer": TransformerDecoder}

__all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder",
"InputFeedRNNDecoder", "str2dec"]
14 changes: 13 additions & 1 deletion onmt/decoders/cnn_decoder.py
@@ -7,11 +7,12 @@

from onmt.modules import ConvMultiStepAttention, GlobalAttention
from onmt.utils.cnn_factory import shape_transform, GatedConv
from onmt.decoders.decoder import DecoderBase

SCALE_WEIGHT = 0.5 ** 0.5


class CNNDecoder(nn.Module):
class CNNDecoder(DecoderBase):
"""
Decoder built on CNN, based on :cite:`DBLP:journals/corr/GehringAGYD17`.
@@ -46,6 +47,17 @@ def __init__(self, num_layers, hidden_size, attn_type,
        else:
            self.copy_attn = None

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings)

    def init_state(self, _, memory_bank, enc_hidden):
        """
        Init decoder state.
24 changes: 23 additions & 1 deletion onmt/decoders/decoder.py
@@ -8,7 +8,13 @@
from onmt.utils.misc import aeq


class RNNDecoderBase(nn.Module):
class DecoderBase(nn.Module):
    @classmethod
    def from_opt(cls, opt, embeddings):
        raise NotImplementedError


class RNNDecoderBase(DecoderBase):
"""
Base recurrent attention-based decoder class.
Specifies the interface used by different decoder types
@@ -102,6 +108,22 @@ def __init__(self, rnn_type, bidirectional_encoder, num_layers,

        self._reuse_copy_attn = reuse_copy_attn and copy_attn

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.global_attention,
            opt.global_attention_function,
            opt.coverage_attn,
            opt.context_gate,
            opt.copy_attn,
            opt.dropout,
            embeddings,
            opt.reuse_copy_attn)

    def init_state(self, src, memory_bank, encoder_final):
        """ Init decoder state with last state of the encoder """
        def _fix_enc_hidden(hidden):
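Because from_opt only reads attributes off opt, any namespace with the right fields will do. A minimal sketch constructing an InputFeedRNNDecoder from a hand-built namespace; the option values and the Embeddings arguments are illustrative, not library defaults:

from types import SimpleNamespace

from onmt.decoders import InputFeedRNNDecoder
from onmt.modules import Embeddings

# Illustrative values mirroring the fields read by RNNDecoderBase.from_opt.
opt = SimpleNamespace(
    rnn_type="LSTM", brnn=True, dec_layers=2, dec_rnn_size=500,
    global_attention="general", global_attention_function="softmax",
    coverage_attn=False, context_gate=None, copy_attn=False,
    dropout=0.3, reuse_copy_attn=False)

embeddings = Embeddings(word_vec_size=500, word_vocab_size=10000,
                        word_padding_idx=1)
decoder = InputFeedRNNDecoder.from_opt(opt, embeddings)
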
16 changes: 15 additions & 1 deletion onmt/decoders/transformer.py
@@ -6,6 +6,7 @@
import torch.nn as nn
import numpy as np

from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward

@@ -107,7 +108,7 @@ def _get_attn_subsequent_mask(self, size):
        return subsequent_mask


class TransformerDecoder(nn.Module):
class TransformerDecoder(DecoderBase):
"""
The Transformer decoder from "Attention is All You Need".
@@ -157,6 +158,19 @@ def __init__(self, num_layers, d_model, heads, d_ff, attn_type,
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.global_attention,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout,
            embeddings)

    def init_state(self, src, memory_bank, enc_hidden):
        """ Init decoder state """
        self.state["src"] = src
9 changes: 8 additions & 1 deletion onmt/encoders/__init__.py
@@ -4,6 +4,13 @@
from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.encoders.image_encoder import ImageEncoder


str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder,
"transformer": TransformerEncoder, "img": ImageEncoder,
"audio": AudioEncoder, "mean": MeanEncoder}

__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder",
"MeanEncoder"]
"MeanEncoder", "str2enc"]
19 changes: 18 additions & 1 deletion onmt/encoders/audio_encoder.py
@@ -7,9 +7,10 @@
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from onmt.utils.rnn_factory import rnn_factory
from onmt.encoders.encoder import EncoderBase


class AudioEncoder(nn.Module):
class AudioEncoder(EncoderBase):
"""
A simple encoder convolutional -> recurrent neural network for
audio input.
@@ -75,6 +76,22 @@ def __init__(self, rnn_type, enc_layers, dec_layers, brnn,
                    nn.MaxPool1d(enc_pooling[l + 1]))
            setattr(self, 'batchnorm_%d' % (l + 1), batchnorm)

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        if embeddings is not None:
            raise ValueError("Cannot use embeddings with AudioEncoder.")
        return cls(
            opt.rnn_type,
            opt.enc_layers,
            opt.dec_layers,
            opt.brnn,
            opt.enc_rnn_size,
            opt.dec_rnn_size,
            opt.audio_enc_pooling,
            opt.dropout,
            opt.sample_rate,
            opt.window_size)

    def forward(self, src, lengths=None):
        "See :obj:`onmt.encoders.encoder.EncoderBase.forward()`"

9 changes: 9 additions & 0 deletions onmt/encoders/cnn_encoder.py
@@ -25,6 +25,15 @@ def __init__(self, num_layers, hidden_size,
        self.cnn = StackedCNN(num_layers, hidden_size,
                              cnn_kernel_width, dropout)

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings)

    def forward(self, input, lengths=None, hidden=None):
        """ See :obj:`onmt.modules.EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)
4 changes: 4 additions & 0 deletions onmt/encoders/encoder.py
@@ -30,6 +30,10 @@ class EncoderBase(nn.Module):
E-->G
"""

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        _, n_batch, _ = src.size()
        if lengths is not None:
21 changes: 20 additions & 1 deletion onmt/encoders/image_encoder.py
@@ -3,8 +3,10 @@
import torch.nn.functional as F
import torch

from onmt.encoders.encoder import EncoderBase

class ImageEncoder(nn.Module):

class ImageEncoder(EncoderBase):
"""
A simple encoder convolutional -> recurrent neural network for
image src.
@@ -47,6 +49,23 @@ def __init__(self, num_layers, bidirectional, rnn_size, dropout,
                           bidirectional=bidirectional)
        self.pos_lut = nn.Embedding(1000, src_size)

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        if embeddings is not None:
            raise ValueError("Cannot use embeddings with ImageEncoder.")
        # why is the model_opt.__dict__ check necessary?
        if "image_channel_size" not in opt.__dict__:
            image_channel_size = 3
        else:
            image_channel_size = opt.image_channel_size
        return cls(
            opt.enc_layers,
            opt.brnn,
            opt.enc_rnn_size,
            opt.dropout,
            image_channel_size
        )

    def load_pretrained_vectors(self, opt):
        """ Pass in needed options only when modify function definition."""
        pass
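The commented question about the opt.__dict__ check is most likely answered by backward compatibility: option namespaces saved with older checkpoints predate image_channel_size, so from_opt falls back to 3 channels when the field is missing. A small sketch of the same fallback written with getattr, using illustrative option values:

from types import SimpleNamespace

# An "old" opt saved before image_channel_size was introduced.
old_opt = SimpleNamespace(enc_layers=2, brnn=True, enc_rnn_size=500, dropout=0.3)

# getattr expresses the same fallback as the explicit __dict__ membership test.
image_channel_size = getattr(old_opt, "image_channel_size", 3)
assert image_channel_size == 3
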
6 changes: 6 additions & 0 deletions onmt/encoders/mean_encoder.py
@@ -15,6 +15,12 @@ def __init__(self, num_layers, embeddings):
        self.num_layers = num_layers
        self.embeddings = embeddings

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.enc_layers,
            embeddings)

    def forward(self, src, lengths=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths)
11 changes: 11 additions & 0 deletions onmt/encoders/rnn_encoder.py
@@ -48,6 +48,17 @@ def __init__(self, rnn_type, bidirectional, num_layers,
                                    hidden_size,
                                    num_layers)

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.dropout,
            embeddings,
            opt.bridge)

    def forward(self, src, lengths=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths)
10 changes: 10 additions & 0 deletions onmt/encoders/transformer.py
@@ -89,6 +89,16 @@ def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
             for i in range(num_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls(
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.dropout,
            embeddings)

    def forward(self, src, lengths=None):
        """ See :obj:`EncoderBase.forward()`"""
        self._check_args(src, lengths)