Skip to content

Commit a7aa440

Browse files
committed
remove batch support for istft
1 parent be082f8 commit a7aa440

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchaudio/functional.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ def istft(stft_matrix, # type: Tensor
8989
9090
Args:
9191
stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each
92-
column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or (
93-
fft_size, n_frames, 2)
92+
column is a window. It has a shape of (fft_size, n_frames, 2)
9493
n_fft (int): Size of Fourier transform
9594
hop_length (Optional[int]): The distance between neighboring sliding window frames.
9695
(Default: ``win_length // 4``)
@@ -107,10 +106,13 @@ def istft(stft_matrix, # type: Tensor
107106
108107
Returns:
109108
torch.Tensor: Least squares estimation of the original signal of size
110-
(batch, signal_length) or (signal_length)
109+
(signal_length)
111110
"""
112111
stft_matrix_dim = stft_matrix.dim()
113-
assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
112+
# Technically this function can accept either (batch, fft_size, n_frames, 2) or
113+
# (fft_size, n_frames, 2) input. Batch support is temporarily disabled (by
114+
# the assert below) to keep torchaudio functions consistent with each other.
115+
assert stft_matrix_dim == 3, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
114116

115117
if stft_matrix_dim == 3:
116118
# add a batch dimension

0 commit comments

Comments
 (0)