Commit

reduce ignores
ShigekiKarita committed Dec 25, 2017
1 parent 1e53828 commit e73eb1d
Showing 6 changed files with 47 additions and 53 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -4,7 +4,7 @@ testpaths = test
 python_paths = src/nets src/utils src/bin
 
 [flake8]
-ignore = H102,H301,H306,H404,H405
+ignore = H102,H306
 # 120 is a workaround, 79 is good
 max-line-length = 120
 exclude = src/utils
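
The dropped codes come from the OpenStack hacking plugin for flake8: H301 (one import per line), H404 (a multi-line docstring should start on the line with the opening quotes), and H405 (the docstring summary should be separated from the body by a blank line) are no longer ignored, so the diffs below bring the code into compliance. A minimal illustration of the two docstring rules, using a hypothetical function that is not part of this commit:

def before(x):
    '''
    A summary starting on the second line violates H404.
    '''

def after(x):
    '''A summary on the opening-quote line satisfies H404.

    A blank line after the summary satisfies H405.
    '''
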
8 changes: 2 additions & 6 deletions src/bin/asr_train.py
@@ -218,9 +218,7 @@ def delete_feat(batch):
 
 
 def adadelta_eps_decay(eps_decay):
-    '''
-    Extension to perform adadelta eps decay
-    '''
+    '''Extension to perform adadelta eps decay'''
     @training.make_extension(trigger=(1, 'epoch'))
     def adadelta_eps_decay(trainer):
         _adadelta_eps_decay(trainer, eps_decay)
@@ -236,9 +234,7 @@ def _adadelta_eps_decay(trainer, eps_decay):
 
 
 def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
-    '''
-    Extension to restore snapshot
-    '''
+    '''Extension to restore snapshot'''
     @training.make_extension(trigger=(1, 'epoch'))
     def restore_snapshot(trainer):
         _restore_snapshot(model, snapshot, load_fn)
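
For context, adadelta_eps_decay and restore_snapshot return Chainer training extensions that fire once per epoch. A sketch of how they are typically registered (the trainer setup, paths, and decay factor are illustrative assumptions, not part of this diff):

# Hypothetical wiring: attach both extensions to a Chainer trainer.
trainer = training.Trainer(updater, (20, 'epoch'), out='exp/train')
trainer.extend(adadelta_eps_decay(0.1))  # presumably scales the optimizer's eps by 0.1
trainer.extend(restore_snapshot(model, 'exp/train/model.acc.best'))  # reload the saved model
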
12 changes: 4 additions & 8 deletions src/bin/asr_train_th.py
@@ -121,14 +121,14 @@ def update_core(self):
 
 # Custom trigger
 class CompareValueTrigger(object):
-    '''Trigger invoked when key value getting bigger or lower than before
+    '''Trigger invoked when key value getting bigger or lower than before
+
     Args:
         key (str): Key of value.
         compare_fn: Function to compare the values.
         trigger: Trigger that decide the comparison interval
     '''
 
     def __init__(self, key, compare_fn, trigger=(1, 'epoch')):
         self._key = key
@@ -222,9 +222,7 @@ def delete_feat(batch):
 
 
 def adadelta_eps_decay(eps_decay):
-    '''
-    Extension to perform adadelta eps decay
-    '''
+    '''Extension to perform adadelta eps decay'''
     @training.make_extension(trigger=(1, 'epoch'))
     def adadelta_eps_decay(trainer):
         _adadelta_eps_decay(trainer, eps_decay)
@@ -240,9 +238,7 @@ def _adadelta_eps_decay(trainer, eps_decay):
 
 
 def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
-    '''
-    Extension to restore snapshot
-    '''
+    '''Extension to restore snapshot'''
     @training.make_extension(trigger=(1, 'epoch'))
     def restore_snapshot(trainer):
         _restore_snapshot(model, snapshot, load_fn)
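
CompareValueTrigger above makes such extensions conditional: it compares a reported value against the best value seen so far and fires only when compare_fn returns True. A hedged sketch of typical use (the key name and comparison are assumptions for illustration):

# Decay eps only when validation accuracy stops improving.
trainer.extend(
    adadelta_eps_decay(0.1),
    trigger=CompareValueTrigger(
        'validation/main/acc',
        lambda best_value, current_value: best_value > current_value))
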
33 changes: 16 additions & 17 deletions src/nets/e2e_asr_attctc.py
@@ -42,8 +42,7 @@ def _get_vgg2l_odim(idim, in_channel=3, out_channel=128):
 
 
 def linear_tensor(linear, x):
-    '''
-    Apply linear matrix operation only for the last dimension of a tensor
+    '''Apply linear matrix operation only for the last dimension of a tensor
 
     :param Link linear: Linear link (M x N matrix)
     :param Variable x: Tensor (D_1 x D_2 x ... x M matrix)
@@ -72,7 +71,7 @@ def __init__(self, predictor, mtlalpha):
         self.predictor = predictor
 
     def __call__(self, x):
-        '''
+        '''Loss forward
 
         :param x:
         :return:
@@ -143,7 +142,7 @@ def __init__(self, idim, odim, args):
 
     # x[i]: ('utt_id', {'ilen':'xxx',...}})
     def __call__(self, data):
-        '''
+        '''E2E forward
 
         :param data:
         :return:
@@ -184,7 +183,7 @@ def __call__(self, data):
         return loss_ctc, loss_att, acc
 
     def recognize(self, x, recog_args, char_list):
-        '''
+        '''E2E greedy/beam search
 
         :param x:
         :param recog_args:
@@ -222,7 +221,7 @@ def __init__(self, odim, eprojs, dropout_rate):
         self.ctc_lo = L.Linear(eprojs, odim)
 
     def __call__(self, hs, ys):
-        '''
+        '''CTC forward
 
         :param hs:
         :param ys:
@@ -273,7 +272,7 @@ def __init__(self, eprojs, dunits, att_dim):
         self.pre_compute_enc_h = None
 
     def reset(self):
-        '''
+        '''reset states
 
         :return:
         '''
@@ -282,7 +281,7 @@ def reset(self):
         self.pre_compute_enc_h = None
 
     def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
-        '''
+        '''AttDot forward
 
         :param enc_hs:
         :param dec_z:
@@ -337,7 +336,7 @@ def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
         self.aconv_chans = aconv_chans
 
     def reset(self):
-        '''
+        '''reset states
 
         :return:
         '''
@@ -346,7 +345,7 @@ def reset(self):
         self.pre_compute_enc_h = None
 
     def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
-        '''
+        '''AttLoc forward
 
         :param enc_hs:
         :param dec_z:
@@ -424,7 +423,7 @@ def __init__(self, eprojs, odim, dlayers, dunits, sos, eos, att, verbose=0, char_list=None):
         self.char_list = char_list
 
     def __call__(self, hs, ys):
-        '''
+        '''Decoder forward
 
         :param hs:
         :param ys:
@@ -499,7 +498,7 @@ def __call__(self, hs, ys):
         return self.loss, acc, att_weight_all
 
     def recognize(self, h, recog_args):
-        '''
+        '''greedy search implementation
 
         :param h:
         :param recog_args:
@@ -534,7 +533,7 @@ def recognize(self, h, recog_args):
         return y_seq
 
     def recognize_beam(self, h, recog_args, char_list):
-        '''
+        '''beam search implementation
 
         :param h:
         :param recog_args:
@@ -693,7 +692,7 @@ def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
         self.etype = etype
 
     def __call__(self, xs, ilens):
-        '''
+        '''Encoder forward
 
         :param xs:
         :param ilens:
@@ -737,7 +736,7 @@ def __init__(self, idim, elayers, cdim, hdim, subsample, dropout):
         self.subsample = subsample
 
     def __call__(self, xs, ilens):
-        '''
+        '''BLSTMP forward
 
         :param xs:
         :param ilens:
@@ -773,7 +772,7 @@ def __init__(self, idim, elayers, cdim, hdim, dropout):
         self.l_last = L.Linear(cdim * 2, hdim)
 
     def __call__(self, xs, ilens):
-        '''
+        '''BLSTM forward
 
         :param xs:
         :param ilens:
@@ -809,7 +808,7 @@ def __init__(self, in_channel=1):
         self.in_channel = in_channel
 
     def __call__(self, xs, ilens):
-        '''
+        '''VGG2L forward
 
         :param xs:
         :param ilens:
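
As the updated docstring says, linear_tensor applies a Linear link to only the last axis of a tensor. A sketch of the equivalent reshape trick in Chainer (the fold/unfold body is an assumption about the implementation, which this diff does not show):

import numpy as np
import chainer.functions as F
import chainer.links as L

linear = L.Linear(40, 320)                           # M=40 -> N=320
x = np.random.randn(8, 100, 40).astype(np.float32)  # D_1 x D_2 x M
# Fold the leading axes, apply the link, then unfold again.
y = F.reshape(linear(F.reshape(x, (-1, 40))), (8, 100, 320))
assert y.shape == (8, 100, 320)                      # D_1 x D_2 x N
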
42 changes: 22 additions & 20 deletions src/nets/e2e_asr_attctc_th.py
@@ -7,7 +7,8 @@
 import torch
 from torch.nn import functional
 from torch.autograd import Variable
-from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
+from torch.nn.utils.rnn import pad_packed_sequence
+from torch.nn.utils.rnn import pack_padded_sequence
 from warpctc_pytorch import _CTC
 import sys
 import logging
@@ -76,8 +77,7 @@ def _get_max_pooled_size(idim, out_channel=128, n_layers=2, ksize=2, stride=2):
 
 
 def linear_tensor(linear, x):
-    '''
-    Apply linear matrix operation only for the last dimension of a tensor
+    '''Apply linear matrix operation only for the last dimension of a tensor
 
     :param Link linear: Linear link (M x N matrix)
     :param Variable x: Tensor (D_1 x D_2 x ... x M matrix)
@@ -113,7 +113,7 @@ def __init__(self, predictor, mtlalpha):
         self.reporter = Reporter()
 
     def forward(self, x):
-        '''
+        '''Loss forward
 
         :param x:
         :return:
@@ -204,7 +204,8 @@ def __init__(self, idim, odim, args):
         # set_forget_bias_to_one(p)
 
     def init_like_chainer(self):
-        """
+        """Initialize weight like chainer
+
         chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
         pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
@@ -223,7 +224,7 @@ def init_like_chainer(self):
 
     # x[i]: ('utt_id', {'ilen':'xxx',...}})
     def forward(self, data):
-        '''
+        '''E2E forward
 
         :param data:
         :return:
@@ -271,7 +272,7 @@ def forward(self, data):
         return loss_ctc, loss_att, acc
 
     def recognize(self, x, recog_args, char_list):
-        '''
+        '''E2E greedy/beam search
 
         :param x:
         :param recog_args:
@@ -312,7 +313,8 @@ def backward(self, grad_output):
 
 
 def chainer_like_ctc_loss(acts, labels, act_lens, label_lens):
-    """
+    """Chainer like CTC Loss
+
     acts: Tensor of (seqLength x batch x outputDim) containing output from network
     labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
     act_lens: Tensor of size (batch) containing size of each output sequence from the network
@@ -335,7 +337,7 @@
         self.loss_fn = chainer_like_ctc_loss  # CTCLoss()
 
     def forward(self, hpad, ilens, ys):
-        '''
+        '''CTC forward
 
         :param hs:
         :param ys:
@@ -392,7 +394,7 @@ def __init__(self, eprojs, dunits, att_dim):
         self.pre_compute_enc_h = None
 
     def reset(self):
-        '''
+        '''reset states
 
         :return:
         '''
@@ -401,7 +403,7 @@ def reset(self):
         self.pre_compute_enc_h = None
 
     def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
-        '''
+        '''AttDot
 
         :param enc_hs:
         :param dec_z:
@@ -455,7 +457,7 @@ def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
         self.aconv_chans = aconv_chans
 
     def reset(self):
-        '''
+        '''reset states
 
         :return:
         '''
@@ -464,7 +466,7 @@ def reset(self):
         self.pre_compute_enc_h = None
 
     def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
-        '''
+        '''AttLoc forward
 
         :param enc_hs:
         :param dec_z:
@@ -551,7 +553,7 @@ def zero_state(self, hpad):
         return Variable(hpad.data.new(hpad.size(0), self.dunits).zero_())
 
     def forward(self, hpad, hlen, ys):
-        '''
+        '''Decoder forward
 
         :param hs:
         :param ys:
@@ -627,7 +629,7 @@ def forward(self, hpad, hlen, ys):
         return self.loss, acc, att_weight_all
 
     def recognize(self, h, recog_args):
-        '''
+        '''greedy search implementation
 
         :param h:
         :param recog_args:
@@ -666,7 +668,7 @@ def recognize(self, h, recog_args):
         return y_seq
 
     def recognize_beam(self, h, recog_args, char_list):
-        '''
+        '''beam search implementation
 
         :param h:
         :param recog_args:
@@ -829,7 +831,7 @@ def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
         self.etype = etype
 
     def forward(self, xs, ilens):
-        '''
+        '''Encoder forward
 
         :param xs:
         :param ilens:
@@ -871,7 +873,7 @@ def __init__(self, idim, elayers, cdim, hdim, subsample, dropout):
         self.subsample = subsample
 
     def forward(self, xpad, ilens):
-        '''
+        '''BLSTMP forward
 
         :param xs:
         :param ilens:
@@ -904,7 +906,7 @@ def __init__(self, idim, elayers, cdim, hdim, dropout):
         self.l_last = torch.nn.Linear(cdim * 2, hdim)
 
     def forward(self, xpad, ilens):
-        '''
+        '''BLSTM forward
 
         :param xs:
         :param ilens:
@@ -935,7 +937,7 @@ def __init__(self, in_channel=1):
         self.in_channel = in_channel
 
     def forward(self, xs, ilens):
-        '''
+        '''VGG2L forward
 
         :param xs:
         :param ilens:
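
The new init_like_chainer summary documents two conventions: Chainer draws W ~ Normal(0, fan_in ** -0.5) with b = 0, while PyTorch defaults to W, b ~ Uniform(-fan_in ** -0.5, fan_in ** -0.5). A minimal sketch of the LeCun-style reset described there (the helper name and wiring are illustrative, not from this diff):

import torch

def lecun_normal_(w):
    # Chainer's default Linear init: W ~ Normal(0, fan_in ** -0.5)
    fan_in = w.size(1)
    with torch.no_grad():
        w.normal_(0, fan_in ** -0.5)

lin = torch.nn.Linear(320, 300)
lecun_normal_(lin.weight)       # replace PyTorch's uniform default
torch.nn.init.zeros_(lin.bias)  # b = 0, as in Chainer
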