Skip to content

Commit 1317643

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent c66d1aa commit 1317643

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9534,13 +9534,13 @@ def forward(self, x):
95349534
[None, 1, 3], # channels
95359535
[16, 32], # n_fft
95369536
[5, 9], # num_frames
9537-
[None, 4, 5], # hop_length
9537+
[None, 5], # hop_length
95389538
[None, 10, 8], # win_length
95399539
[None, torch.hann_window], # window
95409540
[False, True], # center
95419541
[False, True], # normalized
95429542
[None, False, True], # onesided
9543-
[None, 30, 40], # length
9543+
[None, "shorter", "larger"], # length
95449544
[False, True], # return_complex
95459545
)
95469546
)
@@ -9551,9 +9551,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95519551
if hop_length is None and win_length is not None:
95529552
pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")
95539553

9554+
# Compute input_shape to generate test case
95549555
freq = n_fft//2+1 if onesided else n_fft
95559556
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
95569557

9558+
# If not set,c ompute hop_length for capturing errors
9559+
if hop_length is None:
9560+
hop_length = n_fft // 4
9561+
9562+
if length == "shorter":
9563+
length = n_fft//2 + hop_length * (num_frames - 1)
9564+
elif length == "larger":
9565+
length = n_fft*3//2 + hop_length * (num_frames - 1)
9566+
95579567
class ISTFTModel(torch.nn.Module):
95589568
def forward(self, x):
95599569
applied_window = window(win_length) if window and win_length else None
@@ -9573,7 +9583,7 @@ def forward(self, x):
95739583
else:
95749584
return torch.real(x)
95759585

9576-
if win_length and center is False:
9586+
if (center is False and win_length) or (center and win_length and length):
95779587
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
95789588
with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
95799589
TorchBaseTest.run_compare_torch(
@@ -9582,7 +9592,7 @@ def forward(self, x):
95829592
backend=backend,
95839593
compute_unit=compute_unit
95849594
)
9585-
elif length is not None and return_complex is True:
9595+
elif length and return_complex:
95869596
with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
95879597
TorchBaseTest.run_compare_torch(
95889598
input_shape,

0 commit comments

Comments
 (0)