Kaldi MFCC #228

Merged
merged 14 commits into from
Aug 16, 2019
2 changes: 1 addition & 1 deletion README.md
@@ -22,7 +22,7 @@ to use and feel like a natural extension.
- Common audio transforms
- [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/transforms.html)
- Compliance interfaces: Run code using PyTorch that align with other libraries
- - [Kaldi: fbank, spectrogram, resample_waveform](https://pytorch.org/audio/compliance.kaldi.html)
+ - [Kaldi: spectrogram, fbank, mfcc, resample_waveform](https://pytorch.org/audio/compliance.kaldi.html)

Dependencies
------------
9 changes: 7 additions & 2 deletions docs/source/compliance.kaldi.rst
@@ -15,15 +15,20 @@ produce similar outputs.
Functions
---------

+ :hidden:`spectrogram`
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~

+ .. autofunction:: spectrogram

:hidden:`fbank`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: fbank

- :hidden:`spectrogram`
+ :hidden:`mfcc`
~~~~~~~~~~~~~~~~~~~~~~~~~~

- .. autofunction:: spectrogram
+ .. autofunction:: mfcc

:hidden:`resample_waveform`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Binary file not shown.
2 changes: 1 addition & 1 deletion test/compliance/utils.py
@@ -2,7 +2,7 @@
import random
import torchaudio

- TEST_PREFIX = ['fbank', 'spec', 'resample']
+ TEST_PREFIX = ['spec', 'fbank', 'mfcc', 'resample']


def generate_rand_boolean():
34 changes: 34 additions & 0 deletions test/test_compliance_kaldi.py
@@ -210,6 +210,40 @@ def get_output_fn(sound, args):

        self._compliance_test_helper(self.test_filepath, 'fbank', 97, 22, get_output_fn, atol=1e-3, rtol=1e-1)

    def test_mfcc(self):
        def get_output_fn(sound, args):
            output = kaldi.mfcc(
                sound,
                blackman_coeff=args[1],
                dither=0.0,
                energy_floor=args[2],
                frame_length=args[3],
                frame_shift=args[4],
                high_freq=args[5],
                htk_compat=args[6],
                low_freq=args[7],
                num_mel_bins=args[8],
                preemphasis_coefficient=args[9],
                raw_energy=args[10],
                remove_dc_offset=args[11],
                round_to_power_of_two=args[12],
                snip_edges=args[13],
                subtract_mean=args[14],
                use_energy=args[15],
                num_ceps=args[16],
                cepstral_lifter=args[17],
                vtln_high=args[18],
                vtln_low=args[19],
                vtln_warp=args[20],
                window_type=args[21])
            return output

        self._compliance_test_helper(self.test_filepath, 'mfcc', 145, 22, get_output_fn, atol=1e-3)

    def test_mfcc_empty(self):
        # Passing in an empty tensor should result in an error
        self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))

    def test_resample_waveform(self):
        def get_output_fn(sound, args):
            output = kaldi.resample_waveform(sound, args[1], args[2])
154 changes: 142 additions & 12 deletions torchaudio/compliance/kaldi.py
@@ -3,16 +3,17 @@
import fractions
import random
import torch

import torchaudio

__all__ = [
- 'fbank',
'get_mel_banks',
'inverse_mel_scale',
'inverse_mel_scale_scalar',
'mel_scale',
'mel_scale_scalar',
'spectrogram',
+ 'fbank',
+ 'mfcc',
'vtln_warp_freq',
'vtln_warp_mel_freq',
'resample_waveform',
@@ -117,7 +118,9 @@ def _get_waveform_and_window_properties(waveform, channel, sample_frequency, fra
frame_length, round_to_power_of_two, preemphasis_coefficient):
r"""Gets the waveform and window properties
"""
- waveform = waveform[max(channel, 0), :] # size (n)
+ channel = max(channel, 0)
+ assert channel < waveform.size(0), ('Invalid channel %d for size %d' % (channel, waveform.size(0)))
+ waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
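
Reviewer note: the three lines above turn millisecond settings into sample counts. A minimal stdlib-only sketch of that arithmetic follows. The names `window_properties` and `next_power_of_2` are hypothetical stand-ins, and the value `0.001` for `MILLISECONDS_TO_SECONDS` is an assumption based on the constant's name; neither is shown in this diff.

```python
# Assumed value; the diff only shows the constant's name.
MILLISECONDS_TO_SECONDS = 0.001


def next_power_of_2(x):
    """Smallest power of two that is >= x (hypothetical helper)."""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def window_properties(sample_frequency, frame_shift, frame_length,
                      round_to_power_of_two=True):
    """Convert frame settings given in milliseconds into sample counts."""
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    # The FFT input is zero-padded up to the next power of two when requested.
    padded_window_size = (next_power_of_2(window_size)
                          if round_to_power_of_two else window_size)
    return window_shift, window_size, padded_window_size
```

With the defaults documented for `mfcc` (16 kHz, 10 ms shift, 25 ms length) this gives a shift of 160 samples, a window of 400 samples and a padded window of 512 samples.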
@@ -182,6 +185,15 @@ def _get_window(waveform, padded_window_size, window_size, window_shift, window_
return strided_input, signal_log_energy


def _subtract_column_mean(tensor, subtract_mean):
    # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
    # it returns size (m, n)
    if subtract_mean:
        col_means = torch.mean(tensor, dim=0).unsqueeze(0)
        tensor = tensor - col_means
    return tensor
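
Reviewer note: a plain-Python sketch of what the helper above does, using nested lists in place of a tensor. The name `subtract_column_mean` (without the underscore) is a hypothetical stand-in, not part of the PR.

```python
def subtract_column_mean(rows, subtract_mean=True):
    """Subtract each column's mean from an (m, n) list of lists.

    Returns a new (m, n) list; the input is left untouched.
    """
    if not subtract_mean or not rows:
        return [list(row) for row in rows]
    num_rows = len(rows)
    # zip(*rows) iterates over columns, i.e. one feature dimension at a time.
    col_means = [sum(col) / num_rows for col in zip(*rows)]
    return [[value - mean for value, mean in zip(row, col_means)] for row in rows]
```

After the subtraction every column sums to zero, which is the per-file cepstral mean subtraction that the `subtract_mean` option describes.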


def spectrogram(
waveform, blackman_coeff=0.42, channel=-1, dither=1.0, energy_floor=0.0,
frame_length=25.0, frame_shift=10.0, min_duration=0.0,
@@ -239,10 +251,7 @@ def spectrogram(
power_spectrum = torch.max(fft.pow(2).sum(2), EPSILON).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy

- if subtract_mean:
- col_means = torch.mean(power_spectrum, dim=0).unsqueeze(0) # size (1, padded_window_size // 2 + 1)
- power_spectrum = power_spectrum - col_means

+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
return power_spectrum


@@ -504,7 +513,7 @@ def fbank(
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, EPSILON).log()

- # if use_energy then add it as the first column for htk_compat == true else last column
+ # if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
@@ -513,13 +522,134 @@ def fbank(
else:
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

- if subtract_mean:
- col_means = torch.mean(mel_energies, dim=0).unsqueeze(0) # size (1, num_mel_bins + use_energy)
- mel_energies = mel_energies - col_means

+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
return mel_energies


def _get_dct_matrix(num_ceps, num_mel_bins):
    # returns a dct matrix of size (num_mel_bins, num_ceps)
    # size (num_mel_bins, num_mel_bins)
    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, 'ortho')
    # Kaldi expects the first cepstral coefficient to be a weighted sum with factor
    # sqrt(1/num_mel_bins). In torchaudio this is the first column of dct_matrix because
    # it is applied as a right multiply (feature * dct_matrix); in Kaldi it is the first
    # row because Kaldi applies a left multiply (dct_matrix * vector).
    dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
    dct_matrix = dct_matrix[:, :num_ceps]
    return dct_matrix
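
Reviewer note: a stdlib-only sketch of the matrix this helper is meant to return, assuming the standard orthonormal DCT-II definition. It does not call `torchaudio.functional.create_dct`, and the name `dct_matrix` used as a function here is a hypothetical stand-in.

```python
import math


def dct_matrix(num_ceps, num_mel_bins):
    """Orthonormal DCT-II basis of shape (num_mel_bins, num_ceps).

    Laid out for a right multiply: features (m, num_mel_bins) times this matrix.
    """
    scale = math.sqrt(2.0 / num_mel_bins)
    matrix = []
    for n in range(num_mel_bins):
        row = []
        for k in range(num_ceps):
            if k == 0:
                # C0 column: constant weight sqrt(1 / num_mel_bins).
                row.append(math.sqrt(1.0 / num_mel_bins))
            else:
                row.append(scale * math.cos(math.pi / num_mel_bins * (n + 0.5) * k))
        matrix.append(row)
    return matrix
```

Keeping only the first `num_ceps` columns is what truncates the cepstrum; the columns stay orthonormal, so the truncation does not change the retained coefficients.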


def _get_lifter_coeffs(num_ceps, cepstral_lifter):
    # returns size (num_ceps)
    # Compute liftering coefficients (scaling on cepstral coeffs)
    # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
    i = torch.arange(num_ceps, dtype=torch.get_default_dtype())
    return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
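
Reviewer note: the liftering formula above, written with the standard library only so the values are easy to check by hand. The function name `lifter_coeffs` is a hypothetical stand-in for the helper in the PR.

```python
import math


def lifter_coeffs(num_ceps, cepstral_lifter):
    """Scaling factor for each cepstral index i: 1 + (Q / 2) * sin(pi * i / Q)."""
    return [
        1.0 + 0.5 * cepstral_lifter * math.sin(math.pi * i / cepstral_lifter)
        for i in range(num_ceps)
    ]
```

For the default `cepstral_lifter=22.0`, index 0 (C0) keeps a factor of exactly 1.0, and the factor grows to 12.0 at index 11, where the sine reaches its peak.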


def mfcc(
        waveform, blackman_coeff=0.42, cepstral_lifter=22.0, channel=-1, dither=1.0,
        energy_floor=0.0, frame_length=25.0, frame_shift=10.0, high_freq=0.0, htk_compat=False,
        low_freq=20.0, num_ceps=13, min_duration=0.0, num_mel_bins=23, preemphasis_coefficient=0.97,
        raw_energy=True, remove_dc_offset=True, round_to_power_of_two=True,
        sample_frequency=16000.0, snip_edges=True, subtract_mean=False, use_energy=False,
        vtln_high=-500.0, vtln_low=100.0, vtln_warp=1.0, window_type=POVEY):
    r"""Create MFCC features from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (torch.Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``1.0``)
        energy_floor (float): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``0.0``)
        frame_length (float): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (Default: ``0.0``)
        htk_compat (bool): If true, put energy last. Warning: not sufficient to get HTK compatible features (need
            to change other parameters). (Default: ``False``)
        low_freq (float): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool): Use energy (not C0) in MFCC computation. (Default: ``False``)
        vtln_high (float): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq) (Default: ``-500.0``)
        vtln_low (float): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float): VTLN warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``'povey'``)

    Returns:
        torch.Tensor: An MFCC identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (num_ceps, num_mel_bins)

    # The mel_energies should use the power spectrum (use_power=True), not have the mean
    # subtracted (subtract_mean=False), and be in the log domain (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(waveform=waveform, blackman_coeff=blackman_coeff, channel=channel,
                    dither=dither, energy_floor=energy_floor, frame_length=frame_length,
                    frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat,
                    low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins,
                    preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy,
                    remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two,
                    sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False,
                    use_energy=use_energy, use_log_fbank=True, use_power=True,
                    vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type)

    if use_energy:
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs

    # if use_energy then C0 (the first column) is replaced by the log energy;
    # for htk_compat == true it is moved to the last column below
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
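
Reviewer note: the last part of `mfcc` decides where C0 or the log energy ends up. The sketch below replays just that column handling on plain lists, under the assumption that each frame arrives with C0 in position 0. The name `place_c0_or_energy` and its arguments are hypothetical and not part of the PR.

```python
import math


def place_c0_or_energy(cepstra, log_energy=None, htk_compat=False):
    """Mirror the final column handling of the MFCC computation.

    cepstra: list of frames, each a list of cepstral coefficients with C0 first.
    log_energy: optional per-frame log energy that replaces C0 (use_energy=True).
    htk_compat: if True, move the first coefficient to the end of each frame.
    """
    output = []
    for t, frame in enumerate(cepstra):
        frame = list(frame)
        if log_energy is not None:
            frame[0] = log_energy[t]
        if htk_compat:
            first = frame[0]
            if log_energy is None:
                # Undo the 1/sqrt(2) weighting that the orthonormal DCT puts on C0.
                first *= math.sqrt(2)
            frame = frame[1:] + [first]
        output.append(frame)
    return output
```

With `htk_compat=False` the frame order is unchanged; with `htk_compat=True` the first coefficient (C0 or the energy) becomes the last one, which matches the "put energy last" wording in the docstring.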


def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
lowpass_cutoff, lowpass_filter_width):
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for