Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ivector training in pytorch model #3969

Merged
merged 4 commits into from
Mar 3, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions egs/aishell/s10/chain/egs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,17 @@ def __call__(self, batch):
self.egs_left_context + self.egs_right_context

# TODO(fangjun): support ivector
fanlu marked this conversation as resolved.
Show resolved Hide resolved
assert len(eg.inputs) == 1
assert eg.inputs[0].name == 'input'

_feats = kaldi.FloatMatrix()
eg.inputs[0].features.GetMatrix(_feats)
feats = _feats.numpy()

if len(eg.inputs) > 1:
_ivectors = kaldi.FloatMatrix()
eg.inputs[1].features.GetMatrix(_ivectors)
ivectors = _ivectors.numpy()

assert feats.shape[0] == batch_size * frames_per_sequence

feat_list = []
Expand All @@ -173,6 +177,9 @@ def __call__(self, batch):
end_index -= 1 # remove the rightmost frame added for frame shift
feat = feats[start_index:end_index:, :]
feat = splice_feats(feat)
if len(eg.inputs) > 1:
repeat_ivector = torch.from_numpy(ivectors[i]).repeat(feat.shape[0], 1)
feat = torch.cat((torch.from_numpy(feat), repeat_ivector), dim=1).numpy()
feat_list.append(feat)

batched_feat = np.stack(feat_list, axis=0)
Expand All @@ -182,7 +189,10 @@ def __call__(self, batch):
# the first -2 is from extra left/right context
# the second -2 is from lda feats splicing
assert batched_feat.shape[1] == frames_per_sequence - 4
assert batched_feat.shape[2] == feats.shape[-1] * 3
if len(eg.inputs) > 1:
assert batched_feat.shape[2] == feats.shape[-1] * 3 + ivectors.shape[-1]
else:
assert batched_feat.shape[2] == feats.shape[-1] * 3

torch_feat = torch.from_numpy(batched_feat).float()
feature_list.append(torch_feat)
Expand Down
46 changes: 37 additions & 9 deletions egs/aishell/s10/chain/feat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
def get_feat_dataloader(feats_scp,
model_left_context,
model_right_context,
ivector_scp=None,
batch_size=16,
num_workers=10):
dataset = FeatDataset(feats_scp=feats_scp)
dataset = FeatDataset(feats_scp=feats_scp, ivector_scp=ivector_scp)

collate_fn = FeatDatasetCollateFunc(model_left_context=model_left_context,
model_right_context=model_right_context,
Expand Down Expand Up @@ -55,21 +56,36 @@ def _add_model_left_right_context(x, left_context, right_context):

class FeatDataset(Dataset):

def __init__(self, feats_scp):
def __init__(self, feats_scp, ivector_scp=None):
fanlu marked this conversation as resolved.
Show resolved Hide resolved
assert os.path.isfile(feats_scp)

self.feats_scp = feats_scp

# items is a list of [key, rxfilename]
items = list()
# items is a dict of {key: [key, rxfilename, ivec]}
fanlu marked this conversation as resolved.
Show resolved Hide resolved
items = dict()

with open(feats_scp, 'r') as f:
for line in f:
split = line.split()
assert len(split) == 2
items.append(split)

self.items = items
uttid, rxfilename =split
fanlu marked this conversation as resolved.
Show resolved Hide resolved
assert uttid not in items
items[uttid] = [uttid, rxfilename, None]
if ivector_scp:
self.ivector_scp = ivector_scp
expected_count = len(items)
n = 0
with open(ivector_scp, 'r') as f:
for line in f:
uttid_rxfilename = line.split()
assert len(uttid_rxfilename) == 2
uttid, rxfilename = uttid_rxfilename
assert uttid in items
items[uttid][-1] = rxfilename
n += 1
assert n == expected_count

self.items = list(items.values())

self.num_items = len(self.items)

Expand All @@ -81,6 +97,8 @@ def __getitem__(self, i):

def __str__(self):
s = 'feats scp: {}\n'.format(self.feats_scp)
if self.ivector_scp:
fanlu marked this conversation as resolved.
Show resolved Hide resolved
s += 'ivector_scp scp: {}\n'.format(self.ivector_scp)
s += 'num utt: {}\n'.format(self.num_items)
return s

Expand All @@ -105,11 +123,15 @@ def __call__(self, batch):
'''
key_list = []
feat_list = []
ivector_list = []
ivector_len_list = []
output_len_list = []
for b in batch:
key, rxfilename = b
key, rxfilename, ivector_rxfilename = b
key_list.append(key)
feat = kaldi.read_mat(rxfilename).numpy()
if ivector_rxfilename:
ivector = kaldi.read_mat(ivector_rxfilename).numpy() # L // 10 * C
feat_len = feat.shape[0]
output_len = (feat_len + self.frame_subsampling_factor -
1) // self.frame_subsampling_factor
Expand All @@ -120,10 +142,16 @@ def __call__(self, batch):
feat = splice_feats(feat)
feat_list.append(feat)
# no need to sort the feat by length
fanlu marked this conversation as resolved.
Show resolved Hide resolved
ivector_list.append(ivector)
fanlu marked this conversation as resolved.
Show resolved Hide resolved
ivector_len_list.append(ivector.shape[0])

# the user should sort utterances by length offline
# to avoid unnecessary padding
padded_feat = pad_sequence(
[torch.from_numpy(feat).float() for feat in feat_list],
batch_first=True)
return key_list, padded_feat, output_len_list
if ivector_list:
padded_ivector = pad_sequence(
[torch.from_numpy(ivector).float() for ivector in ivector_list],
batch_first=True)
return key_list, padded_feat, output_len_list, padded_ivector, ivector_len_list
32 changes: 27 additions & 5 deletions egs/aishell/s10/chain/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def main():
model = get_chain_model(
feat_dim=args.feat_dim,
output_dim=args.output_dim,
ivector_dim=args.ivector_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
bottleneck_dim=args.bottleneck_dim,
Expand All @@ -64,15 +65,36 @@ def main():

dataloader = get_feat_dataloader(
feats_scp=args.feats_scp,
ivector_scp=args.ivector_scp,
model_left_context=args.model_left_context,
model_right_context=args.model_right_context,
batch_size=32)

batch_size=1,
num_workers=0)
subsampling_factor = 3
subsampled_frames_per_chunk = args.frames_per_chunk // subsampling_factor
for batch_idx, batch in enumerate(dataloader):
key_list, padded_feat, output_len_list = batch
key_list, padded_feat, output_len_list, padded_ivector, ivector_len_list = batch
padded_feat = padded_feat.to(device)
if ivector_len_list:
padded_ivector = padded_ivector.to(device)
with torch.no_grad():
nnet_output, _ = model(padded_feat)
nnet_outputs = []
input_num_frames = padded_feat.shape[1] + 2 \
- args.model_left_context - args.model_right_context
for i in range(0, output_len_list[0], subsampled_frames_per_chunk):
# 418 -> [0, 17, 34, 51, 68, 85, 102, 119, 136]
first_output = i * subsampling_factor
last_output = min(input_num_frames, \
first_output + (subsampled_frames_per_chunk-1) * subsampling_factor)
first_input = first_output
last_input = last_output + args.model_left_context + args.model_right_context
input_x = padded_feat[:, first_input:last_input+1, :]
ivector_index = (first_output + last_output) // 2 // args.ivector_period
input_ivector = padded_ivector[:, ivector_index, :]
feat = torch.cat((input_x, input_ivector.repeat((1, input_x.shape[1], 1))), dim=-1)
nnet_output_temp, _ = model(feat)
nnet_outputs.append(nnet_output_temp)
nnet_output = torch.cat(nnet_outputs, dim=1)

num = len(key_list)
for i in range(num):
Expand All @@ -85,7 +107,7 @@ def main():
m = Matrix(m)
writer.Write(key, m)

if batch_idx % 10 == 0:
if batch_idx % 100 == 0:
logging.info('Processed batch {}/{} ({:.6f}%)'.format(
batch_idx, len(dataloader),
float(batch_idx) / len(dataloader) * 100))
Expand Down
10 changes: 7 additions & 3 deletions egs/aishell/s10/chain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

def get_chain_model(feat_dim,
output_dim,
ivector_dim,
hidden_dim,
bottleneck_dim,
prefinal_bottleneck_dim,
Expand All @@ -25,6 +26,7 @@ def get_chain_model(feat_dim,
lda_mat_filename=None):
model = ChainModel(feat_dim=feat_dim,
output_dim=output_dim,
ivector_dim=ivector_dim,
lda_mat_filename=lda_mat_filename,
hidden_dim=hidden_dim,
bottleneck_dim=bottleneck_dim,
Expand Down Expand Up @@ -82,6 +84,7 @@ class ChainModel(nn.Module):
def __init__(self,
feat_dim,
output_dim,
ivector_dim=0,
lda_mat_filename=None,
hidden_dim=1024,
bottleneck_dim=128,
Expand All @@ -97,8 +100,9 @@ def __init__(self,
assert len(kernel_size_list) == len(subsampling_factor_list)
num_layers = len(kernel_size_list)

input_dim = feat_dim * 3 + ivector_dim
# tdnn1_affine requires [N, T, C]
self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
self.tdnn1_affine = nn.Linear(in_features=input_dim,
out_features=hidden_dim)

# tdnn1_batchnorm requires [N, C, T]
Expand Down Expand Up @@ -142,11 +146,11 @@ def __init__(self,
if lda_mat_filename:
logging.info('Use LDA from {}'.format(lda_mat_filename))
self.lda_A, self.lda_b = load_lda_mat(lda_mat_filename)
assert feat_dim * 3 == self.lda_A.shape[0]
assert input_dim == self.lda_A.shape[0]
self.has_LDA = True
else:
logging.info('replace LDA with BatchNorm')
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
affine=False)
self.has_LDA = False

Expand Down
26 changes: 25 additions & 1 deletion egs/aishell/s10/chain/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ def _set_inference_args(parser):
dest='feats_scp',
help='feats.scp filename, required for inference',
type=str)

parser.add_argument('--frames-per-chunk',
dest='frames_per_chunk',
help='frames per chunk',
type=int,
default=51)

parser.add_argument('--ivector-scp',
dest='ivector_scp',
help='ivector.scp filename, required for ivector inference',
type=str)

parser.add_argument('--ivector-period',
dest='ivector_period',
help='ivector period',
type=int,
default=10)

parser.add_argument('--model-left-context',
dest='model_left_context',
Expand Down Expand Up @@ -228,10 +245,17 @@ def get_args():

parser.add_argument('--feat-dim',
dest='feat_dim',
help='nn input dimension',
help='nn input 0 dimension',
required=True,
type=int)

parser.add_argument('--ivector-dim',
dest='ivector_dim',
help='nn input 1 dimension',
required=False,
default=0,
type=int)

parser.add_argument('--output-dim',
dest='output_dim',
help='nn output dimension',
Expand Down
1 change: 1 addition & 0 deletions egs/aishell/s10/chain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def process_job(learning_rate, local_rank=None):
model = get_chain_model(
feat_dim=args.feat_dim,
output_dim=args.output_dim,
ivector_dim=args.ivector_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
bottleneck_dim=args.bottleneck_dim,
Expand Down
1 change: 1 addition & 0 deletions egs/aishell/s10/conf/online_cmvn.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# configuration file for apply-cmvn-online, used when invoking online2-wav-nnet3-latgen-faster.
Loading