Skip to content

Commit

Permalink
add black to ci and change autopep8 -> pycodestyle
Browse files: browse the repository at this point in the history
  • Loading branch information
kamo-naoyuki committed Dec 24, 2019
1 parent 4f8b4ec commit 3e22a98
Show file tree
Hide file tree
Showing 54 changed files with 427 additions and 576 deletions.
10 changes: 8 additions & 2 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ set -euo pipefail

"$(dirname $0)"/test_flake8.sh

autopep8 -r espnet espnet2 test utils --global-config .pep8 --diff --max-line-length 120 | tee check_autopep8
test ! -s check_autopep8
pycodestyle -r espnet test utils --show-source --show-pep8
if ! black --check espnet2 test/espnet2 setup.py; then
echo "Please apply: 'black espnet2/ test/espnet2 setup.py'"
exit 1
fi

# espnet2 follows "black" style.
pycodestyle -r espnet2 test/espnet2 setup.py --max-line-length 88 --ignore E203,W503 --show-source --show-pep8

LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}:$(pwd)/tools/chainer_ctc/ext/warp-ctc/build" pytest
24 changes: 5 additions & 19 deletions espnet2/asr/decoder/rnn_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@ def build_attention_list(
if num_encs == 1:
for i in range(num_att):
att = initial_att(
atype,
eprojs,
dunits,
aheads,
adim,
awin,
aconv_chans,
aconv_filts,
atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts,
)
att_list.append(att)
elif num_encs > 1: # no multi-speaker mode
Expand Down Expand Up @@ -155,8 +148,7 @@ def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for l in range(1, self.dlayers):
z_list[l], c_list[l] = self.decoder[l](
self.dropout_dec[l - 1](z_list[l - 1]),
(z_prev[l], c_prev[l]),
self.dropout_dec[l - 1](z_list[l - 1]), (z_prev[l], c_prev[l]),
)
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
Expand Down Expand Up @@ -220,9 +212,7 @@ def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
)
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in_pad)
att_c, att_w_list[self.num_encs] = self.att_list[
self.num_encs
](
att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
hs_pad_han,
hlens_han,
self.dropout_dec[0](z_list[0]),
Expand All @@ -236,14 +226,10 @@ def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
else:
# utt x (zdim + hdim)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
z_list, c_list = self.rnn_forward(
ey, z_list, c_list, z_list, c_list
)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(
torch.cat(
(self.dropout_dec[-1](z_list[-1]), att_c), dim=-1
)
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
Expand Down
19 changes: 5 additions & 14 deletions espnet2/asr/decoder/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import (
MultiHeadedAttention,
)
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.embedding import (
PositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
Expand Down Expand Up @@ -82,13 +78,10 @@ def __init__(
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
input_layer, pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise NotImplementedError(
"only `embed` or torch.nn.Module is supported."
)
raise NotImplementedError("only `embed` or torch.nn.Module is supported.")

self.normalize_before = normalize_before
self.decoders = repeat(
Expand All @@ -101,9 +94,7 @@ def __init__(
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(
attention_dim, linear_units, dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
Expand Down
24 changes: 6 additions & 18 deletions espnet2/asr/e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,14 @@ def forward(

# 2c. RNN-T branch
if self.rnnt_decoder is not None:
_ = self._calc_rnnt_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
_ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)

if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = (
self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
)
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

stats = dict(
loss=loss.detach(),
Expand All @@ -150,9 +146,7 @@ def forward(
)

# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable(
(loss, stats, batch_size), loss.device
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

def collect_feats(
Expand Down Expand Up @@ -223,9 +217,7 @@ def _calc_att_loss(
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(
ys_pad, self.sos, self.eos, self.ignore_id
)
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1

# 1. Forward decoder
Expand All @@ -246,9 +238,7 @@ def _calc_att_loss(
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(
ys_hat.cpu(), ys_pad.cpu()
)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

return loss_att, acc_att, cer_att, wer_att

Expand All @@ -266,9 +256,7 @@ def _calc_ctc_loss(
cer_ctc = None
if self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(
ys_hat.cpu(), ys_pad.cpu(), is_ctc=True
)
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc

def _calc_rnnt_loss(
Expand Down
15 changes: 11 additions & 4 deletions espnet2/asr/encoder/rnn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,16 @@ def __init__(

else:
self.enc = torch.nn.ModuleList(
[RNN(input_size, num_layers, hidden_size, output_size, dropout, typ=rnn_type)]
[
RNN(
input_size,
num_layers,
hidden_size,
output_size,
dropout,
typ=rnn_type,
)
]
)

def output_size(self) -> int:
Expand All @@ -97,9 +106,7 @@ def forward(

current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(
xs_pad, ilens, prev_state=prev_state
)
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)

if self.use_projection:
Expand Down
24 changes: 6 additions & 18 deletions espnet2/asr/encoder/transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,17 @@
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import (
MultiHeadedAttention,
)
from espnet.nets.pytorch_backend.transformer.embedding import (
PositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
Conv1dLinear,
)
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
MultiLayeredConv1d,
)
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
)
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
from espnet2.asr.encoder.abs_encoder import AbsEncoder


Expand Down Expand Up @@ -91,9 +81,7 @@ def __init__(
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(
input_size, output_size, padding_idx=padding_idx
),
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
Expand Down
4 changes: 1 addition & 3 deletions espnet2/asr/encoder/vgg_rnn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def forward(

current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(
xs_pad, ilens, prev_state=prev_state
)
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)

if self.use_projection:
Expand Down
7 changes: 1 addition & 6 deletions espnet2/asr/frontend/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,7 @@ def __init__(
self.frontend = None

self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
htk=htk,
norm=norm,
fs=fs, n_fft=n_fft, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
)
self.n_mels = n_mels

Expand Down
4 changes: 1 addition & 3 deletions espnet2/bin/aggregate_stats_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@


def aggregate_stats_dirs(
input_dir: Iterable[Union[str, Path]],
output_dir: Union[str, Path],
log_level: str,
input_dir: Iterable[Union[str, Path]], output_dir: Union[str, Path], log_level: str,
):
logging.basicConfig(
level=log_level,
Expand Down
32 changes: 7 additions & 25 deletions espnet2/bin/asr_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,7 @@ def recog(

# 4. Build BeamSearch object
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
length_bonus=penalty,
decoder=1.0 - ctc_weight, ctc=ctc_weight, lm=lm_weight, length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
Expand Down Expand Up @@ -210,9 +207,7 @@ def recog(
ibest_writer = writer[f"{n}best_recog"]

# Write the result to each files
ibest_writer["token"][key] = " ".join(token).replace(
blank_symbol, ""
)
ibest_writer["token"][key] = " ".join(token).replace(blank_symbol, "")
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)

Expand All @@ -230,9 +225,7 @@ def get_parser():

# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--config", is_config_file=True, help="config file path"
)
parser.add_argument("--config", is_config_file=True, help="config file path")

parser.add_argument(
"--log_level",
Expand All @@ -244,10 +237,7 @@ def get_parser():

parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
"--ngpu", type=int, default=0, help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
Expand All @@ -271,9 +261,7 @@ def get_parser():
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument(
"--allow_variable_data_keys", type=str2bool, default=False
)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

group = parser.add_argument_group("The model configuration related")
group.add_argument("--asr_train_config", type=str, required=True)
Expand All @@ -285,10 +273,7 @@ def get_parser():

group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
"--batch_size", type=int, default=1, help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
Expand All @@ -309,10 +294,7 @@ def get_parser():
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
"--ctc_weight", type=float, default=0.5, help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument(
Expand Down
Loading

0 comments on commit 3e22a98

Please sign in to comment.