Complex STFT transform from spectrogram #327

Merged
merged 9 commits into from
Nov 18, 2019
12 changes: 12 additions & 0 deletions test/test_transforms.py
@@ -313,6 +313,18 @@ def test_compute_deltas_twochannel(self):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

def test_batch_spectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)

# Single then transform then batch
expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

# Batch then transform
computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))


if __name__ == '__main__':
unittest.main()
37 changes: 23 additions & 14 deletions torchaudio/functional.py
@@ -96,8 +96,7 @@ def istft(

Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. it has a size of either (channel, fft_size, n_frame, 2) or (
fft_size, n_frame, 2)
column is a window. it has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``)
Expand Down Expand Up @@ -218,42 +217,52 @@ def istft(
def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized
):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
# type: (Tensor, int, Tensor, int, int, int, Optional[int], bool) -> Tensor
r"""
spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)

Create a spectrogram from a raw audio signal.
Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.

Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT
hop_length (int): Length of hop between STFT windows
win_length (int): Window size
power (int): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft

Returns:
torch.Tensor: Dimension (channel, freq, time), where channel
torch.Tensor: Dimension (..., channel, freq, time), where channel
is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
assert waveform.dim() == 2

if pad > 0:
# TODO add "with torch.no_grad():" back when JIT supports it
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = _stft(
waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
)

# unpack batch
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])

Contributor:

@vincentqb is support for batching going to be done optionally then? 😁 I can add it to the augmentation transforms if needed.

Contributor Author:

Yeah, we decided to go ahead and add batching using reshape for now :) We'll update this code when nested tensors come out.


if normalized:
spec_f /= window.pow(2).sum().sqrt()
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor
if power is not None:
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor

return spec_f
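
The "pack batch" / "unpack batch" steps in the hunk above can be sketched in isolation. This is a minimal illustration, not the code from the PR: `batched_stft` is a hypothetical helper, the Hann window and the `n_fft`/`hop_length` defaults are assumptions, and it calls `torch.stft` with `return_complex=True` (needed on recent PyTorch releases), so the result is a complex tensor rather than one with a trailing real/imaginary axis of size 2.

```python
import torch


def batched_stft(waveform, n_fft=400, hop_length=200):
    """Hypothetical helper: STFT over the last axis of a (..., time) tensor."""
    window = torch.hann_window(n_fft)  # assumed window, for illustration only

    # Pack batch: collapse every leading dimension into a single one,
    # because torch.stft accepts only 1D or 2D input.
    shape = waveform.size()
    packed = waveform.reshape(-1, shape[-1])

    # return_complex=True yields a complex tensor of shape (N, freq, n_frame).
    # The PR predates this flag and works with a (N, freq, n_frame, 2) layout.
    spec = torch.stft(
        packed, n_fft, hop_length=hop_length, win_length=n_fft, window=window,
        center=True, pad_mode="reflect", normalized=False, onesided=True,
        return_complex=True,
    )

    # Unpack batch: restore the leading dimensions, keep (freq, n_frame).
    return spec.reshape(shape[:-1] + spec.shape[-2:])


torch.manual_seed(0)
single = torch.randn(2, 4000)      # (channel, time)
batch = single.repeat(3, 1, 1)     # (batch, channel, time)

spec_single = batched_stft(single)
spec_batch = batched_stft(batch)
print(spec_single.shape, spec_batch.shape)
```

As in `test_batch_spectrogram`, every item of the batched result should match the result of transforming the single waveform on its own, up to floating-point tolerance.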


Expand Down Expand Up @@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0):
r"""Compute the norm of complex tensor input.

Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`).

Returns:
torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
torch.Tensor: Power of the normed input tensor. Shape of `(..., )`
"""
if power == 1.0:
return torch.norm(complex_tensor, 2, -1)
@@ -448,21 +457,21 @@ def angle(complex_tensor):
r"""Compute the angle of complex tensor input.

Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`

Return:
torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
torch.Tensor: Angle of a complex tensor. Shape of `(..., )`
"""
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


@torch.jit.script
def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor]
r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase.
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`)

Returns:
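
The `(..., complex=2)` convention in the docstrings above stores the real and imaginary parts in a trailing axis of size 2. For a single (real, imag) pair, the quantities named by `complex_norm`, `angle`, and `magphase` can be sketched with the standard library alone. The `*_pair` functions below are hypothetical scalar analogues, not torchaudio API. Two things are assumed because the diff does not show them: that the norm is raised to `power` when `power` is not `1.0`, and that `magphase` returns the pair (norm, angle).

```python
import math


def complex_norm_pair(pair, power=1.0):
    """Scalar analogue of complex_norm: |z| ** power for z = re + i*im."""
    re, im = pair
    return math.hypot(re, im) ** power


def angle_pair(pair):
    """Scalar analogue of angle: atan2(imag, real), as in the diff."""
    re, im = pair
    return math.atan2(im, re)


def magphase_pair(pair, power=1.0):
    """Assumed behavior: return (magnitude ** power, phase)."""
    return complex_norm_pair(pair, power), angle_pair(pair)


mag, phase = magphase_pair((3.0, 4.0))
print(mag, phase)
```

The tensor versions apply the same arithmetic elementwise over the leading `...` dimensions, which is why the output shape drops only the trailing axis.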