Skip to content

Commit 272bb90

Browse files
authored
adding ifft2 method to ops (#20447)
* adding ifft2 method to ops * fixes all test checks * using built-in versions in backends
1 parent e488c6e commit 272bb90

File tree

8 files changed

+133
-0
lines changed

8 files changed

+133
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from keras.src.ops.math import extract_sequences
4848
from keras.src.ops.math import fft
4949
from keras.src.ops.math import fft2
50+
from keras.src.ops.math import ifft2
5051
from keras.src.ops.math import in_top_k
5152
from keras.src.ops.math import irfft
5253
from keras.src.ops.math import istft

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from keras.src.ops.math import extract_sequences
4848
from keras.src.ops.math import fft
4949
from keras.src.ops.math import fft2
50+
from keras.src.ops.math import ifft2
5051
from keras.src.ops.math import in_top_k
5152
from keras.src.ops.math import irfft
5253
from keras.src.ops.math import istft

keras/src/backend/jax/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ def fft2(x):
123123
return jnp.real(complex_output), jnp.imag(complex_output)
124124

125125

126+
def ifft2(x):
127+
complex_input = _get_complex_tensor_from_tuple(x)
128+
complex_output = jnp.fft.ifft2(complex_input)
129+
return jnp.real(complex_output), jnp.imag(complex_output)
130+
131+
126132
def rfft(x, fft_length=None):
127133
complex_output = jnp.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
128134
return jnp.real(complex_output), jnp.imag(complex_output)

keras/src/backend/numpy/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ def fft2(x):
144144
return np.array(real), np.array(imag)
145145

146146

147+
def ifft2(x):
148+
complex_input = _get_complex_tensor_from_tuple(x)
149+
complex_output = np.fft.ifft2(complex_input)
150+
return np.real(complex_output), np.imag(complex_output)
151+
152+
147153
def rfft(x, fft_length=None):
148154
complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
149155
# numpy always outputs complex128, so we need to recast the dtype

keras/src/backend/tensorflow/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ def fft2(x):
113113
return tf.math.real(complex_output), tf.math.imag(complex_output)
114114

115115

116+
def ifft2(x):
117+
real, imag = x
118+
h = cast(tf.shape(real)[-2], "float32")
119+
w = cast(tf.shape(real)[-1], "float32")
120+
real_conj, imag_conj = real, -imag
121+
fft_real, fft_imag = fft2((real_conj, imag_conj))
122+
return fft_real / (h * w), -fft_imag / (h * w)
123+
124+
116125
def rfft(x, fft_length=None):
117126
if fft_length is not None:
118127
fft_length = [fft_length]

keras/src/backend/torch/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ def fft2(x):
203203
return complex_output.real, complex_output.imag
204204

205205

206+
def ifft2(x):
207+
complex_input = _get_complex_tensor_from_tuple(x)
208+
complex_output = torch.fft.ifft2(complex_input)
209+
return complex_output.real, complex_output.imag
210+
211+
206212
def rfft(x, fft_length=None):
207213
x = convert_to_tensor(x)
208214
complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm="backward")

keras/src/ops/math.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,81 @@ def fft2(x):
473473
return backend.math.fft2(x)
474474

475475

476+
class IFFT2(Operation):
477+
def __init__(self):
478+
super().__init__()
479+
self.axes = (-2, -1)
480+
481+
def compute_output_spec(self, x):
482+
if not isinstance(x, (tuple, list)) or len(x) != 2:
483+
raise ValueError(
484+
"Input `x` should be a tuple of two tensors - real and "
485+
f"imaginary. Received: x={x}"
486+
)
487+
488+
real, imag = x
489+
# Both real and imaginary parts should have the same shape.
490+
if real.shape != imag.shape:
491+
raise ValueError(
492+
"Input `x` should be a tuple of two tensors - real and "
493+
"imaginary. Both the real and imaginary parts should have the "
494+
f"same shape. Received: x[0].shape = {real.shape}, "
495+
f"x[1].shape = {imag.shape}"
496+
)
497+
# We are calculating 2D IFFT. Hence, rank >= 2.
498+
if len(real.shape) < 2:
499+
raise ValueError(
500+
f"Input should have rank >= 2. "
501+
f"Received: input.shape = {real.shape}"
502+
)
503+
504+
# The axes along which we are calculating IFFT should be fully-defined.
505+
m = real.shape[self.axes[0]]
506+
n = real.shape[self.axes[1]]
507+
if m is None or n is None:
508+
raise ValueError(
509+
f"Input should have its {self.axes} axes fully-defined. "
510+
f"Received: input.shape = {real.shape}"
511+
)
512+
513+
return (
514+
KerasTensor(shape=real.shape, dtype=real.dtype),
515+
KerasTensor(shape=imag.shape, dtype=imag.dtype),
516+
)
517+
518+
def call(self, x):
519+
return backend.math.ifft2(x)
520+
521+
522+
@keras_export("keras.ops.ifft2")
523+
def ifft2(x):
524+
"""Computes the 2D Inverse Fast Fourier Transform along the last two axes of
525+
input.
526+
527+
Args:
528+
x: Tuple of the real and imaginary parts of the input tensor. Both
529+
tensors in the tuple should be of floating type.
530+
531+
Returns:
532+
A tuple containing two tensors - the real and imaginary parts of the
533+
output.
534+
535+
Example:
536+
537+
>>> x = (
538+
... keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
539+
... keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
540+
... )
541+
>>> ifft2(x)
542+
(array([[ 6., 0.],
543+
[ 0., -2.]], dtype=float32), array([[ 2., 0.],
544+
[ 0., -2.]], dtype=float32))
545+
"""
546+
if any_symbolic_tensors(x):
547+
return IFFT2().symbolic_call(x)
548+
return backend.math.ifft2(x)
549+
550+
476551
class RFFT(Operation):
477552
def __init__(self, fft_length=None):
478553
super().__init__()

keras/src/ops/math_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,15 @@ def test_fft2(self):
218218
self.assertEqual(real_output.shape, ref_shape)
219219
self.assertEqual(imag_output.shape, ref_shape)
220220

221+
def test_ifft2(self):
222+
real = KerasTensor((None, 4, 3), dtype="float32")
223+
imag = KerasTensor((None, 4, 3), dtype="float32")
224+
real_output, imag_output = kmath.ifft2((real, imag))
225+
ref = np.fft.ifft2(np.ones((2, 4, 3)))
226+
ref_shape = (None,) + ref.shape[1:]
227+
self.assertEqual(real_output.shape, ref_shape)
228+
self.assertEqual(imag_output.shape, ref_shape)
229+
221230
@parameterized.parameters([(None,), (1,), (5,)])
222231
def test_rfft(self, fft_length):
223232
x = KerasTensor((None, 4, 3), dtype="float32")
@@ -354,6 +363,14 @@ def test_fft2(self):
354363
self.assertEqual(real_output.shape, ref.shape)
355364
self.assertEqual(imag_output.shape, ref.shape)
356365

366+
def test_ifft2(self):
367+
real = KerasTensor((2, 4, 3), dtype="float32")
368+
imag = KerasTensor((2, 4, 3), dtype="float32")
369+
real_output, imag_output = kmath.ifft2((real, imag))
370+
ref = np.fft.ifft2(np.ones((2, 4, 3)))
371+
self.assertEqual(real_output.shape, ref.shape)
372+
self.assertEqual(imag_output.shape, ref.shape)
373+
357374
def test_rfft(self):
358375
x = KerasTensor((2, 4, 3), dtype="float32")
359376
real_output, imag_output = kmath.rfft(x)
@@ -715,6 +732,18 @@ def test_fft2(self):
715732
self.assertAllClose(real_ref, real_output)
716733
self.assertAllClose(imag_ref, imag_output)
717734

735+
def test_ifft2(self):
736+
real = np.random.random((2, 4, 3)).astype(np.float32)
737+
imag = np.random.random((2, 4, 3)).astype(np.float32)
738+
complex_arr = real + 1j * imag
739+
740+
real_output, imag_output = kmath.ifft2((real, imag))
741+
ref = np.fft.ifft2(complex_arr)
742+
real_ref = np.real(ref)
743+
imag_ref = np.imag(ref)
744+
self.assertAllClose(real_ref, real_output)
745+
self.assertAllClose(imag_ref, imag_output)
746+
718747
@parameterized.parameters([(None,), (3,), (15,)])
719748
def test_rfft(self, n):
720749
# Test 1D.

0 commit comments

Comments
 (0)