@@ -9534,13 +9534,13 @@ def forward(self, x):
             [None, 1, 3],  # channels
             [16, 32],  # n_fft
             [5, 9],  # num_frames
-            [None, 4, 5],  # hop_length
+            [None, 5],  # hop_length
             [None, 10, 8],  # win_length
             [None, torch.hann_window],  # window
             [False, True],  # center
             [False, True],  # normalized
             [None, False, True],  # onesided
-            [None, 30, 40],  # length
+            [None, "shorter", "larger"],  # length
             [False, True],  # return_complex
         )
     )
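Each of the option lists above contributes one axis of the test matrix; the decorator that consumes them sits above this hunk and is not shown. A minimal sketch of how such lists are typically expanded, assuming an itertools.product-style parametrization (not necessarily the exact mechanism used here):

import itertools

# Every combination of the per-parameter lists becomes one test case.
cases = list(itertools.product(
    [None, 1, 3],                   # channels
    [None, 5],                      # hop_length (narrowed by this change)
    [None, "shorter", "larger"],    # length (symbolic values resolved inside the test)
))
print(len(cases))  # 3 * 2 * 3 = 18 combinations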
@@ -9551,9 +9551,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
         if hop_length is None and win_length is not None:
             pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")

+        # Compute input_shape to generate the test case
         freq = n_fft // 2 + 1 if onesided else n_fft
         input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)

+        # If not set, compute the default hop_length so the expected errors below can be derived
+        if hop_length is None:
+            hop_length = n_fft // 4
+
+        if length == "shorter":
+            length = n_fft // 2 + hop_length * (num_frames - 1)
+        elif length == "larger":
+            length = n_fft * 3 // 2 + hop_length * (num_frames - 1)
+
         class ISTFTModel(torch.nn.Module):
             def forward(self, x):
                 applied_window = window(win_length) if window and win_length else None
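To make the new symbolic length values concrete, here is a small worked example of what the added branches compute, using n_fft=16 and num_frames=5 from the parameter lists above and the n_fft // 4 fallback for hop_length (illustrative values only, not part of the diff):

n_fft, num_frames = 16, 5
hop_length = n_fft // 4                                   # 4, the fallback applied above
shorter = n_fft // 2 + hop_length * (num_frames - 1)      # 8 + 16 = 24
larger = n_fft * 3 // 2 + hop_length * (num_frames - 1)   # 24 + 16 = 40
print(shorter, larger)  # 24 40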
@@ -9573,7 +9583,7 @@ def forward(self, x):
                 else:
                     return torch.real(x)

-        if win_length and center is False:
+        if (center is False and win_length) or (center and win_length and length):
            # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
            with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
                TorchBaseTest.run_compare_torch(
@@ -9582,7 +9592,7 @@ def forward(self, x):
                     backend=backend,
                     compute_unit=compute_unit
                 )
-        elif length is not None and return_complex is True:
+        elif length and return_complex:
            with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
                TorchBaseTest.run_compare_torch(
                    input_shape,