@@ -218,6 +218,15 @@ def test_fft2(self):
218218 self .assertEqual (real_output .shape , ref_shape )
219219 self .assertEqual (imag_output .shape , ref_shape )
220220
221+ def test_ifft2 (self ):
222+ real = KerasTensor ((None , 4 , 3 ), dtype = "float32" )
223+ imag = KerasTensor ((None , 4 , 3 ), dtype = "float32" )
224+ real_output , imag_output = kmath .ifft2 ((real , imag ))
225+ ref = np .fft .ifft2 (np .ones ((2 , 4 , 3 )))
226+ ref_shape = (None ,) + ref .shape [1 :]
227+ self .assertEqual (real_output .shape , ref_shape )
228+ self .assertEqual (imag_output .shape , ref_shape )
229+
221230 @parameterized .parameters ([(None ,), (1 ,), (5 ,)])
222231 def test_rfft (self , fft_length ):
223232 x = KerasTensor ((None , 4 , 3 ), dtype = "float32" )
@@ -354,6 +363,14 @@ def test_fft2(self):
354363 self .assertEqual (real_output .shape , ref .shape )
355364 self .assertEqual (imag_output .shape , ref .shape )
356365
366+ def test_ifft2 (self ):
367+ real = KerasTensor ((2 , 4 , 3 ), dtype = "float32" )
368+ imag = KerasTensor ((2 , 4 , 3 ), dtype = "float32" )
369+ real_output , imag_output = kmath .ifft2 ((real , imag ))
370+ ref = np .fft .ifft2 (np .ones ((2 , 4 , 3 )))
371+ self .assertEqual (real_output .shape , ref .shape )
372+ self .assertEqual (imag_output .shape , ref .shape )
373+
357374 def test_rfft (self ):
358375 x = KerasTensor ((2 , 4 , 3 ), dtype = "float32" )
359376 real_output , imag_output = kmath .rfft (x )
@@ -715,6 +732,18 @@ def test_fft2(self):
715732 self .assertAllClose (real_ref , real_output )
716733 self .assertAllClose (imag_ref , imag_output )
717734
735+ def test_ifft2 (self ):
736+ real = np .random .random ((2 , 4 , 3 )).astype (np .float32 )
737+ imag = np .random .random ((2 , 4 , 3 )).astype (np .float32 )
738+ complex_arr = real + 1j * imag
739+
740+ real_output , imag_output = kmath .ifft2 ((real , imag ))
741+ ref = np .fft .ifft2 (complex_arr )
742+ real_ref = np .real (ref )
743+ imag_ref = np .imag (ref )
744+ self .assertAllClose (real_ref , real_output )
745+ self .assertAllClose (imag_ref , imag_output )
746+
718747 @parameterized .parameters ([(None ,), (3 ,), (15 ,)])
719748 def test_rfft (self , n ):
720749 # Test 1D.
0 commit comments