-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add kaldi's equivalent
add-deltas
to PyTorch.
- Loading branch information
1 parent
ffa861c
commit 5c7fdec
Showing
2 changed files
with
173 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) | ||
# Apache 2.0 | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
|
||
def compute_delta_feat(x, weight): | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
csukuangfj
Author
Contributor
|
||
''' | ||
Args: | ||
x: input feat of shape [batch_size, seq_len, feat_dim] | ||
weight: coefficients for computing delta features; | ||
it has a shape of [feat_dim, 1, kernel_size]. | ||
Returns: | ||
a tensor fo shape [batch_size, seq_len, feat_dim] | ||
''' | ||
|
||
assert x.ndim == 3 | ||
assert weight.ndim == 3 | ||
assert weight.size(0) == x.size(2) | ||
assert weight.size(1) == 1 | ||
assert weight.size(2) % 2 == 1 | ||
|
||
feat_dim = x.size(2) | ||
|
||
pad_size = weight.size(2) // 2 | ||
|
||
# F.pad requires a 4-D tensor in our case | ||
x = x.unsqueeze(0) | ||
|
||
# (0, 0, pad_size, pad_size) == (left, right, top, bottom) | ||
padded_x = F.pad(x, (0, 0, pad_size, pad_size), mode='replicate') | ||
|
||
# after padding, we have to convert it back to 3-D | ||
# since conv1d requires 3-D input | ||
padded_x = padded_x.squeeze(0) | ||
|
||
# conv1d requires a shape of [batch_size, feat_dim, seq_len] | ||
padded_x = padded_x.permute(0, 2, 1) | ||
|
||
# NOTE(fangjun): we perform a depthwise convolution here by | ||
# setting groups == number of channels | ||
y = F.conv1d(input=padded_x, weight=weight, groups=feat_dim) | ||
|
||
# now convert y back to be of shape [batch_size, seq_len, feat_dim] | ||
y = y.permute(0, 2, 1) | ||
|
||
return y | ||
|
||
|
||
class AddDeltasTransform: | ||
''' | ||
This class implements `add-deltas` in kaldi with | ||
order == 2 and window == 2. | ||
It generates the identical output as kaldi's `add-deltas` with default | ||
parameters given the same input. | ||
''' | ||
|
||
def __init__(self): | ||
# yapf: disable | ||
self.first_order_coef = torch.tensor([-0.2, -0.1, 0, 0.1, 0.2]) | ||
self.second_order_coef = torch.tensor([0.04, 0.04, 0.01, -0.04, -0.1, -0.04, 0.01, 0.04, 0.04]) | ||
# yapf: enable | ||
|
||
# TODO(fangjun): change the coefficients to the following as suggested by Dan | ||
# [-1, 0, 1] | ||
# [1, 0, -2, 0, 1] | ||
|
||
def __call__(self, x): | ||
''' | ||
Args: | ||
x: a tensor of shape [batch_size, seq_len, feat_dim] | ||
Returns: | ||
a tensor of shape [batch_size, seq_len, feat_dim * 3] | ||
''' | ||
if self.first_order_coef.ndim != 3: | ||
num_duplicates = x.size(2) | ||
|
||
# yapf: disable | ||
self.first_order_coef = self.first_order_coef.reshape(1, 1, -1) | ||
self.first_order_coef = torch.cat([self.first_order_coef] * num_duplicates, dim=0) | ||
|
||
self.second_order_coef = self.second_order_coef.reshape(1, 1, -1) | ||
self.second_order_coef = torch.cat([self.second_order_coef] * num_duplicates, dim=0) | ||
# yapf: enable | ||
|
||
device = x.device | ||
self.first_order_coef = self.first_order_coef.to(device) | ||
self.second_order_coef = self.second_order_coef.to(device) | ||
|
||
first_order = compute_delta_feat(x, self.first_order_coef) | ||
second_order = compute_delta_feat(x, self.second_order_coef) | ||
|
||
y = torch.cat([x, first_order, second_order], dim=2) | ||
|
||
return y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) | ||
# Apache 2.0 | ||
|
||
import os | ||
import shutil | ||
import tempfile | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
import kaldi | ||
|
||
from transform import AddDeltasTransform | ||
|
||
|
||
class TransformTest(unittest.TestCase): | ||
|
||
def test_add_deltas_transform(self): | ||
x = torch.tensor([ | ||
[1, 3], | ||
[5, 10], | ||
[0, 1], | ||
[10, 20], | ||
[3, 1], | ||
[3, 2], | ||
[5, 1], | ||
[10, -2], | ||
]).float() | ||
|
||
x = x.unsqueeze(0) | ||
|
||
transform = AddDeltasTransform() | ||
y = transform(x) | ||
|
||
# now use kaldi's add-deltas to compute the ground truth | ||
d = tempfile.mkdtemp() | ||
|
||
wspecifier = 'ark:{}/feats.ark'.format(d) | ||
|
||
writer = kaldi.MatrixWriter(wspecifier) | ||
writer.Write('utt1', x.squeeze(0).numpy()) | ||
writer.Close() | ||
|
||
delta_feats_specifier = 'ark:{dir}/delta.ark'.format(dir=d) | ||
|
||
cmd = ''' | ||
add-deltas --print-args=false --delta-order=2 --delta-window=2 {} {} | ||
'''.format(wspecifier, delta_feats_specifier) | ||
|
||
os.system(cmd) | ||
|
||
reader = kaldi.RandomAccessMatrixReader(delta_feats_specifier) | ||
|
||
expected = reader['utt1'] | ||
|
||
y = y.squeeze(0) | ||
|
||
np.testing.assert_array_almost_equal(y.numpy(), expected.numpy()) | ||
|
||
reader.Close() | ||
|
||
shutil.rmtree(d) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
I didn't actually want a version to be implemented that was compatible with what the binary does (although that may be useful for other purposes). What I want is something similar to what we do in the example neural net scripts. The one I am referring to has no padding (it adds context) and uses a smaller context; the windows are [1], [-1,0,1], [1,0,-2,0,1]. This will be easier to replicate the nnet training scripts, and introduces less discontinuities at the edges.