Skip to content

Commit acd1dda

Browse files
committed
adding ifft2 method to ops
1 parent c052cea commit acd1dda

File tree

6 files changed

+140
-0
lines changed

6 files changed

+140
-0
lines changed

keras/src/backend/jax/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ def fft2(x):
123123
return jnp.real(complex_output), jnp.imag(complex_output)
124124

125125

126+
def ifft2(x):
127+
real, imag = x
128+
H = cast(jnp.shape(real)[-2], jnp.float32)
129+
W = cast(jnp.shape(real)[-1], jnp.float32)
130+
real_conj, imag_conj = real, -imag
131+
fft_real, fft_imag = fft2((real_conj, imag_conj))
132+
return fft_real / (H * W), fft_imag / (H * W)
133+
134+
126135
def rfft(x, fft_length=None):
127136
complex_output = jnp.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
128137
return jnp.real(complex_output), jnp.imag(complex_output)

keras/src/backend/numpy/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,15 @@ def fft2(x):
144144
return np.array(real), np.array(imag)
145145

146146

147+
def ifft2(x):
148+
real, imag = x
149+
H = np.float32(real.shape[-2])
150+
W = np.float32(real.shape[-1])
151+
real_conj, imag_conj = real, -imag
152+
fft_real, fft_imag = fft2((real_conj, imag_conj))
153+
return fft_real / (H * W), -fft_imag / (H * W)
154+
155+
147156
def rfft(x, fft_length=None):
148157
complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
149158
# numpy always outputs complex128, so we need to recast the dtype

keras/src/backend/tensorflow/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ def fft2(x):
113113
return tf.math.real(complex_output), tf.math.imag(complex_output)
114114

115115

116+
def ifft2(x):
117+
real, imag = x
118+
H = cast(tf.shape(real)[-2], "float32")
119+
W = cast(tf.shape(real)[-1], "float32")
120+
real_conj, imag_conj = real, -imag
121+
fft_real, fft_imag = fft2((real_conj, imag_conj))
122+
return fft_real / (H * W), -fft_imag / (H * W)
123+
124+
116125
def rfft(x, fft_length=None):
117126
if fft_length is not None:
118127
fft_length = [fft_length]

keras/src/backend/torch/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,15 @@ def fft2(x):
203203
return complex_output.real, complex_output.imag
204204

205205

206+
def ifft2(x):
207+
real, imag = x
208+
H = cast(torch.tensor(real.shape[-2]), "float32")
209+
W = cast(torch.tensor(real.shape[-1]), "float32")
210+
real_conj, imag_conj = real, -imag
211+
fft_real, fft_imag = fft2((real_conj, imag_conj))
212+
return fft_real / (H * W), -fft_imag / (H * W)
213+
214+
206215
def rfft(x, fft_length=None):
207216
x = convert_to_tensor(x)
208217
complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm="backward")

keras/src/ops/math.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,81 @@ def fft2(x):
475475
return backend.math.fft2(x)
476476

477477

478+
class IFFT2(Operation):
479+
def __init__(self):
480+
super().__init__()
481+
self.axes = (-2, -1)
482+
483+
def compute_output_spec(self, x):
484+
if not isinstance(x, (tuple, list)) or len(x) != 2:
485+
raise ValueError(
486+
"Input `x` should be a tuple of two tensors - real and "
487+
f"imaginary. Received: x={x}"
488+
)
489+
490+
real, imag = x
491+
# Both real and imaginary parts should have the same shape.
492+
if real.shape != imag.shape:
493+
raise ValueError(
494+
"Input `x` should be a tuple of two tensors - real and "
495+
"imaginary. Both the real and imaginary parts should have the "
496+
f"same shape. Received: x[0].shape = {real.shape}, "
497+
f"x[1].shape = {imag.shape}"
498+
)
499+
# We are calculating 2D IFFT. Hence, rank >= 2.
500+
if len(real.shape) < 2:
501+
raise ValueError(
502+
f"Input should have rank >= 2. "
503+
f"Received: input.shape = {real.shape}"
504+
)
505+
506+
# The axes along which we are calculating IFFT should be fully-defined.
507+
m = real.shape[self.axes[0]]
508+
n = real.shape[self.axes[1]]
509+
if m is None or n is None:
510+
raise ValueError(
511+
f"Input should have its {self.axes} axes fully-defined. "
512+
f"Received: input.shape = {real.shape}"
513+
)
514+
515+
return (
516+
KerasTensor(shape=real.shape, dtype=real.dtype),
517+
KerasTensor(shape=imag.shape, dtype=imag.dtype),
518+
)
519+
520+
def call(self, x):
521+
return backend.math.ifft2(x)
522+
523+
524+
@keras_export("keras.ops.ifft2")
525+
def ifft2(x):
526+
"""Computes the 2D Inverse Fast Fourier Transform along the last two axes of
527+
input.
528+
529+
Args:
530+
x: Tuple of the real and imaginary parts of the input tensor. Both
531+
tensors in the tuple should be of floating type.
532+
533+
Returns:
534+
A tuple containing two tensors - the real and imaginary parts of the
535+
output.
536+
537+
Example:
538+
539+
>>> x = (
540+
... keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
541+
... keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
542+
... )
543+
>>> ifft2(x)
544+
(array([[ 6., 0.],
545+
[ 0., -2.]], dtype=float32), array([[ 2., 0.],
546+
[ 0., -2.]], dtype=float32))
547+
"""
548+
if any_symbolic_tensors(x):
549+
return IFFT2().symbolic_call(x)
550+
return backend.math.ifft2(x)
551+
552+
478553
class RFFT(Operation):
479554
def __init__(self, fft_length=None):
480555
super().__init__()

keras/src/ops/math_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,15 @@ def test_fft2(self):
219219
self.assertEqual(real_output.shape, ref_shape)
220220
self.assertEqual(imag_output.shape, ref_shape)
221221

222+
def test_ifft2(self):
223+
real = KerasTensor((None, 4, 3), dtype="float32")
224+
imag = KerasTensor((None, 4, 3), dtype="float32")
225+
real_output, imag_output = kmath.ifft2((real, imag))
226+
ref = np.fft.ifft2(np.ones((2, 4, 3)))
227+
ref_shape = (None,) + ref.shape[1:]
228+
self.assertEqual(real_output.shape, ref_shape)
229+
self.assertEqual(imag_output.shape, ref_shape)
230+
222231
@parameterized.parameters([(None,), (1,), (5,)])
223232
def test_rfft(self, fft_length):
224233
x = KerasTensor((None, 4, 3), dtype="float32")
@@ -355,6 +364,14 @@ def test_fft2(self):
355364
self.assertEqual(real_output.shape, ref.shape)
356365
self.assertEqual(imag_output.shape, ref.shape)
357366

367+
def test_ifft2(self):
368+
real = KerasTensor((2, 4, 3), dtype="float32")
369+
imag = KerasTensor((2, 4, 3), dtype="float32")
370+
real_output, imag_output = kmath.ifft2((real, imag))
371+
ref = np.fft.ifft2(np.ones((2, 4, 3)))
372+
self.assertEqual(real_output.shape, ref.shape)
373+
self.assertEqual(imag_output.shape, ref.shape)
374+
358375
def test_rfft(self):
359376
x = KerasTensor((2, 4, 3), dtype="float32")
360377
real_output, imag_output = kmath.rfft(x)
@@ -717,6 +734,18 @@ def test_fft2(self):
717734
self.assertAllClose(real_ref, real_output)
718735
self.assertAllClose(imag_ref, imag_output)
719736

737+
def test_ifft2(self):
738+
real = np.random.random((2, 4, 3))
739+
imag = np.random.random((2, 4, 3))
740+
complex_arr = real + 1j * imag
741+
742+
real_output, imag_output = kmath.ifft2((real, imag))
743+
ref = np.fft.ifft2(complex_arr)
744+
real_ref = np.real(ref)
745+
imag_ref = np.imag(ref)
746+
self.assertAllClose(real_ref, real_output)
747+
self.assertAllClose(imag_ref, imag_output)
748+
720749
@parameterized.parameters([(None,), (3,), (15,)])
721750
def test_rfft(self, n):
722751
# Test 1D.

0 commit comments

Comments
 (0)