Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: add TDNNF to pytorch. #3892

Merged
merged 5 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update training scripts.
  • Loading branch information
csukuangfj committed Jan 31, 2020
commit 154e36680bf401bc3eeff3403f899b4be68667c1
5 changes: 3 additions & 2 deletions egs/aishell/s10/chain/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ def main():
output_dim=args.output_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
kernel_size_list=args.kernel_size_list,
stride_list=args.stride_list)
bottleneck_dim=args.bottleneck_dim,
time_stride_list=args.time_stride_list,
conv_stride_list=args.conv_stride_list)

load_checkpoint(args.checkpoint, model)

Expand Down
9 changes: 9 additions & 0 deletions egs/aishell/s10/chain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ def forward(self, x):

return nnet_output, xent_output

def constrain_orthonormal(self):
for i in range(len(self.tdnnfs)):
self.tdnnfs[i].constrain_orthonormal()

self.prefinal_l.constrain_orthonormal()
self.prefinal_chain.constrain_orthonormal()
self.prefinal_xent.constrain_orthonormal()


if __name__ == '__main__':
feat_dim = 43
Expand All @@ -212,3 +220,4 @@ def forward(self, x):
x = torch.arange(N * T * C).reshape(N, T, C).float()
nnet_output, xent_output = model(x)
print(x.shape, nnet_output.shape, xent_output.shape)
model.constrain_orthonormal()
33 changes: 20 additions & 13 deletions egs/aishell/s10/chain/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,19 @@ def _check_args(args):
assert args.feat_dim > 0
assert args.output_dim > 0
assert args.hidden_dim > 0
assert args.bottleneck_dim > 0

assert args.kernel_size_list is not None
assert len(args.kernel_size_list) > 0
assert args.time_stride_list is not None
assert len(args.time_stride_list) > 0

assert args.stride_list is not None
assert len(args.stride_list) > 0
assert args.conv_stride_list is not None
assert len(args.conv_stride_list) > 0

args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]

args.stride_list = [int(k) for k in args.stride_list.split(', ')]
args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]

assert len(args.kernel_size_list) == len(args.stride_list)
assert len(args.time_stride_list) == len(args.conv_stride_list)

assert args.log_level in ['debug', 'info', 'warning']

Expand Down Expand Up @@ -195,15 +196,21 @@ def get_args():
required=True,
type=int)

parser.add_argument('--kernel-size-list',
dest='kernel_size_list',
help='kernel size list',
parser.add_argument('--bottleneck-dim',
dest='bottleneck_dim',
help='nn bottleneck dimension',
required=True,
type=int)

parser.add_argument('--time-stride-list',
dest='time_stride_list',
help='time stride list',
required=True,
type=str)

parser.add_argument('--stride-list',
dest='stride_list',
help='stride list',
parser.add_argument('--conv-stride-list',
dest='conv_stride_list',
help='conv stride list',
required=True,
type=str)

Expand Down
32 changes: 20 additions & 12 deletions egs/aishell/s10/chain/tdnnf_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn.functional as F


def _constraint_orthonormal_internal(M):
def _constrain_orthonormal_internal(M):
'''
Refer to
void ConstrainOrthonormalInternal(BaseFloat scale, CuMatrixBase<BaseFloat> *M)
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, dim, bottleneck_dim, time_stride):
assert time_stride in [0, 1]
# WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
# and [0, 1] for the second affine layer;
# We use [-1, 0, 1] for the first linear layer
# we use [-1, 0, 1] for the first linear layer if time_stride == 1

if time_stride == 0:
kernel_size = 1
Expand All @@ -79,7 +79,7 @@ def forward(self, x):
x = self.conv(x)
return x

def constraint_orthonormal(self):
def constrain_orthonormal(self):
state_dict = self.conv.state_dict()
w = state_dict['weight']
# w is of shape [out_channels, in_channels, kernel_size]
Expand All @@ -97,7 +97,7 @@ def constraint_orthonormal(self):
w = w.t()
need_transpose = True

w = _constraint_orthonormal_internal(w)
w = _constrain_orthonormal_internal(w)

if need_transpose:
w = w.t()
Expand Down Expand Up @@ -142,6 +142,9 @@ def forward(self, x):

return x

def constrain_orthonormal(self):
self.linear.constrain_orthonormal()


class FactorizedTDNN(nn.Module):
'''
Expand Down Expand Up @@ -175,6 +178,8 @@ def __init__(self,
time_stride=time_stride)

# affine requires [N, C, T]
# WARNING(fangjun): we do not use nn.Linear here
# since we want to use `stride`
self.affine = nn.Conv1d(in_channels=bottleneck_dim,
out_channels=dim,
kernel_size=1,
Expand All @@ -191,31 +196,34 @@ def forward(self, x):
input_x = x

x = self.linear(x)

# at this point, x is [N, C, T]

x = self.affine(x)

# at this point, x is [N, C, T]

x = F.relu(x)

# at this point, x is [N, C, T]

x = self.batchnorm(x)

# at this point, x is [N, C, T]

# TODO(fangjun): implement GeneralDropoutComponent in PyTorch

# at this point, x is [N, C, T]
if self.linear.kernel_size == 3:
x = self.bypass_scale * input_x[:, :, 1:-1:self.conv_stride] + x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be c:-c:c rather than 1:-1:c, where c is self.conv_stride?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suppose the time_stride is 1 and the conv_stride is 1.

If the input time index is

0 1 2 3 4 5 6

After self.linear, the time index will be

1 2 3 4 5

since the kernel shape is [-1, 0, 1] (time_stride == 1)

After self.affine, the time index is still

1 2 3 4 5

The index of input[1:-1:self.conv_stride] is [1, 2, 3, 4, 5] which matches
the output of self.affine.


It is assumed that

  • time_stride == 1, conv_stride == 1

or

  • time_stride == 0, conv_stride == 3

So c:-c:c is equivalent to 1:-1:c when time_stride==1 and conv_stride == 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it should be called time_stride here. Perhaps in the original Kaldi code it wasn't super clear but when implemented as convolution it gets very confusing. Better to make (stride, kernel_size) the parameters and have them be (1, 3), (1, 3), ... (3, 3), (1, 1), (1, 3), (1, 3) ...
In any case, please revert other aspects of the implementation to more similar to the way it was before and start doing experiments with that. I don't see much point starting from such a strange starting point. (i.e. the way the code is right now).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.

I also find them confusing but I wrote it this way to follow the naming style in Kaldi.

I'll change them now.

else:
x = self.bypass_scale * input_x[:, :, ::self.conv_stride] + x
return x

def constraint_orthonormal(self):
self.linear.constraint_orthonormal()
def constrain_orthonormal(self):
self.linear.constrain_orthonormal()


def _test_constraint_orthonormal():
def _test_constrain_orthonormal():

def compute_loss(M):
P = torch.mm(M, M.t())
Expand All @@ -238,7 +246,7 @@ def compute_loss(M):
loss.append(compute_loss(w))

for i in range(15):
w = _constraint_orthonormal_internal(w)
w = _constrain_orthonormal_internal(w)
loss.append(compute_loss(w))

for i in range(1, len(loss)):
Expand All @@ -252,11 +260,11 @@ def compute_loss(M):
time_stride=1,
conv_stride=3)
loss = []
model.constraint_orthonormal()
model.constrain_orthonormal()
loss.append(
compute_loss(model.linear.conv.state_dict()['weight'].reshape(128, -1)))
for i in range(5):
model.constraint_orthonormal()
model.constrain_orthonormal()
loss.append(
compute_loss(model.linear.conv.state_dict()['weight'].reshape(
128, -1)))
Expand Down Expand Up @@ -308,4 +316,4 @@ def _test_factorized_tdnn():
if __name__ == '__main__':
torch.manual_seed(20200130)
_test_factorized_tdnn()
_test_constraint_orthonormal()
_test_constrain_orthonormal()
11 changes: 9 additions & 2 deletions egs/aishell/s10/chain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# disable warnings when loading tensorboard
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_value_
Expand Down Expand Up @@ -84,6 +85,11 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
total_weight += objf_l2_term_weight[2].item()
num_frames = nnet_output.shape[0]
total_frames += num_frames

if np.random.choice(4) == 0:
with torch.no_grad():
model.constraint_orthonormal()

if batch_idx % 100 == 0:
logging.info(
'Process {}/{}({:.6f}%) global average objf: {:.6f} over {} '
Expand Down Expand Up @@ -135,8 +141,9 @@ def main():
output_dim=args.output_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
kernel_size_list=args.kernel_size_list,
stride_list=args.stride_list)
bottleneck_dim=args.bottleneck_dim,
time_stride_list=args.time_stride_list,
conv_stride_list=args.conv_stride_list)

start_epoch = 0
num_epochs = args.num_epochs
Expand Down
10 changes: 10 additions & 0 deletions egs/aishell/s10/conf/mfcc_hires.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# config for high-resolution MFCC features, intended for neural network training.
# Note: we keep all cepstra, so it has the same info as filterbank features,
# but MFCC is more easily compressible (because less correlated) which is why
# we prefer this method.
--use-energy=false # use average of log energy, not energy.
--sample-frequency=16000 # AISHELL-2 is sampled at 16kHz
--num-mel-bins=40 # similar to Google's setup.
--num-ceps=40 # there is no dimensionality reduction.
--low-freq=20 # low cutoff frequency for mel bins
--high-freq=-400 # high cutoff frequency, relative to Nyquist of 8000 (=7600)
46 changes: 27 additions & 19 deletions egs/aishell/s10/local/run_chain.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ stage=0

# GPU device id to use (count from 0).
# you can also set `CUDA_VISIBLE_DEVICES` and set `device_id=0`
device_id=0
device_id=6

nj=10

Expand All @@ -19,8 +19,8 @@ lat_dir=exp/tri5a_lats # input lat dir
treedir=exp/chain/tri5_tree # output tree dir

# You should know how to calculate your model's left/right context **manually**
model_left_context=12
model_right_context=12
model_left_context=28
model_right_context=28
egs_left_context=$[$model_left_context + 1]
egs_right_context=$[$model_right_context + 1]
frames_per_eg=150,110,90
Expand All @@ -30,9 +30,10 @@ minibatch_size=128
num_epochs=6
lr=1e-3

hidden_dim=625
kernel_size_list="1, 3, 3, 3, 3, 3" # comma separated list
stride_list="1, 1, 3, 1, 1, 1" # comma separated list
hidden_dim=1024
bottleneck_dim=128
time_stride_list="1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
conv_stride_list="1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list

log_level=info # valid values: debug, info, warning

Expand All @@ -47,11 +48,16 @@ save_nn_output_as_compressed=false

if [[ $stage -le 0 ]]; then
for datadir in train dev test; do
dst_dir=data/fbank_pitch/$datadir
dst_dir=data/mfcc_hires/$datadir
if [[ ! -f $dst_dir/feats.scp ]]; then
echo "making mfcc-pitch features for LF-MMI training"
utils/copy_data_dir.sh data/$datadir $dst_dir
echo "making fbank-pitch features for LF-MMI training"
steps/make_fbank_pitch.sh --cmd $train_cmd --nj $nj $dst_dir || exit 1
steps/make_mfcc_pitch.sh \
--mfcc-config conf/mfcc_hires.conf \
--pitch-config conf/pitch.conf \
--cmd "$train_cmd" \
--nj $nj \
$dst_dir || exit 1
steps/compute_cmvn_stats.sh $dst_dir || exit 1
utils/fix_data_dir.sh $dst_dir
else
Expand Down Expand Up @@ -80,12 +86,12 @@ if [[ $stage -le 2 ]]; then
# step compared with other recipes.
steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \
--context-opts "--context-width=2 --central-position=1" \
--cmd "$train_cmd" 5000 data/train $lang $ali_dir $treedir
--cmd "$train_cmd" 5000 data/mfcc/train $lang $ali_dir $treedir
fi

if [[ $stage -le 3 ]]; then
echo "creating phone language-model"
$train_cmd exp/chain/log/make_phone_lm.log \
"$train_cmd" exp/chain/log/make_phone_lm.log \
chain-est-phone-lm \
"ark:gunzip -c $treedir/ali.*.gz | ali-to-phones $treedir/final.mdl ark:- ark:- |" \
exp/chain/phone_lm.fst || exit 1
Expand All @@ -95,7 +101,7 @@ if [[ $stage -le 4 ]]; then
echo "creating denominator FST"
copy-transition-model $treedir/final.mdl exp/chain/0.trans_mdl
cp $treedir/tree exp/chain
$train_cmd exp/chain/log/make_den_fst.log \
"$train_cmd" exp/chain/log/make_den_fst.log \
chain-make-den-fst exp/chain/tree exp/chain/0.trans_mdl exp/chain/phone_lm.fst \
exp/chain/den.fst exp/chain/normalization.fst || exit 1
fi
Expand All @@ -119,7 +125,7 @@ if [[ $stage -le 5 ]]; then
--right-tolerance 5 \
--srand 0 \
--stage -10 \
data/fbank_pitch/train \
data/mfcc_hires/train \
exp/chain $lat_dir exp/chain/egs
fi

Expand Down Expand Up @@ -157,16 +163,17 @@ if [[ $stage -le 8 ]]; then

# sort the options alphabetically
python3 ./chain/train.py \
--bottleneck-dim $bottleneck_dim \
--checkpoint=${train_checkpoint:-} \
--conv-stride-list "$conv_stride_list" \
--device-id $device_id \
--dir exp/chain/train \
--feat-dim $feat_dim \
--hidden-dim $hidden_dim \
--is-training true \
--kernel-size-list "$kernel_size_list" \
--log-level $log_level \
--output-dim $output_dim \
--stride-list "$stride_list" \
--time-stride-list "$time_stride_list" \
--train.cegs-dir exp/chain/merged_egs \
--train.den-fst exp/chain/den.fst \
--train.egs-left-context $egs_left_context \
Expand All @@ -186,20 +193,21 @@ if [[ $stage -le 9 ]]; then
best_epoch=$(cat exp/chain/train/best-epoch-info | grep 'best epoch' | awk '{print $NF}')
inference_checkpoint=exp/chain/train/epoch-${best_epoch}.pt
python3 ./chain/inference.py \
--bottleneck-dim $bottleneck_dim \
--checkpoint $inference_checkpoint \
--conv-stride-list "$conv_stride_list" \
--device-id $device_id \
--dir exp/chain/inference/$x \
--feat-dim $feat_dim \
--feats-scp data/fbank_pitch/$x/feats.scp \
--feats-scp data/mfcc_hires/$x/feats.scp \
--hidden-dim $hidden_dim \
--is-training false \
--kernel-size-list "$kernel_size_list" \
--log-level $log_level \
--model-left-context $model_left_context \
--model-right-context $model_right_context \
--output-dim $output_dim \
--save-as-compressed $save_nn_output_as_compressed \
--stride-list "$stride_list" || exit 1
--time-stride-list "$time_stride_list" || exit 1
fi
done
fi
Expand Down Expand Up @@ -228,7 +236,7 @@ if [[ $stage -le 11 ]]; then

for x in test dev; do
./local/score.sh --cmd "$decode_cmd" \
data/fbank_pitch/$x \
data/mfcc_hires/$x \
exp/chain/graph \
exp/chain/decode_res/$x || exit 1
done
Expand Down
Loading