@@ -219,6 +219,15 @@ def test_fft2(self):
219219 self .assertEqual (real_output .shape , ref_shape )
220220 self .assertEqual (imag_output .shape , ref_shape )
221221
222+ def test_ifft2 (self ):
223+ real = KerasTensor ((None , 4 , 3 ), dtype = "float32" )
224+ imag = KerasTensor ((None , 4 , 3 ), dtype = "float32" )
225+ real_output , imag_output = kmath .ifft2 ((real , imag ))
226+ ref = np .fft .ifft2 (np .ones ((2 , 4 , 3 )))
227+ ref_shape = (None ,) + ref .shape [1 :]
228+ self .assertEqual (real_output .shape , ref_shape )
229+ self .assertEqual (imag_output .shape , ref_shape )
230+
222231 @parameterized .parameters ([(None ,), (1 ,), (5 ,)])
223232 def test_rfft (self , fft_length ):
224233 x = KerasTensor ((None , 4 , 3 ), dtype = "float32" )
@@ -355,6 +364,14 @@ def test_fft2(self):
355364 self .assertEqual (real_output .shape , ref .shape )
356365 self .assertEqual (imag_output .shape , ref .shape )
357366
367+ def test_ifft2 (self ):
368+ real = KerasTensor ((2 , 4 , 3 ), dtype = "float32" )
369+ imag = KerasTensor ((2 , 4 , 3 ), dtype = "float32" )
370+ real_output , imag_output = kmath .ifft2 ((real , imag ))
371+ ref = np .fft .ifft2 (np .ones ((2 , 4 , 3 )))
372+ self .assertEqual (real_output .shape , ref .shape )
373+ self .assertEqual (imag_output .shape , ref .shape )
374+
358375 def test_rfft (self ):
359376 x = KerasTensor ((2 , 4 , 3 ), dtype = "float32" )
360377 real_output , imag_output = kmath .rfft (x )
@@ -717,6 +734,18 @@ def test_fft2(self):
717734 self .assertAllClose (real_ref , real_output )
718735 self .assertAllClose (imag_ref , imag_output )
719736
737+ def test_ifft2 (self ):
738+ real = np .random .random ((2 , 4 , 3 ))
739+ imag = np .random .random ((2 , 4 , 3 ))
740+ complex_arr = real + 1j * imag
741+
742+ real_output , imag_output = kmath .ifft2 ((real , imag ))
743+ ref = np .fft .ifft2 (complex_arr )
744+ real_ref = np .real (ref )
745+ imag_ref = np .imag (ref )
746+ self .assertAllClose (real_ref , real_output )
747+ self .assertAllClose (imag_ref , imag_output )
748+
720749 @parameterized .parameters ([(None ,), (3 ,), (15 ,)])
721750 def test_rfft (self , n ):
722751 # Test 1D.
0 commit comments