From 52348dab797c5df0b01e2b75cfd2244cd30397bf Mon Sep 17 00:00:00 2001 From: lijiaqi0612 <33169170+lijiaqi0612@users.noreply.github.com> Date: Fri, 10 Sep 2021 11:48:28 +0800 Subject: [PATCH] Add C2R Python layer normal and abnormal use cases (#29) * documents and single case * test c2r case * New C2R Python layer normal and exception use cases --- .../fluid/tests/unittests/fft/test_fft.py | 434 +++++++++++++++++- .../tests/unittests/fft/test_spectral_op.py | 49 ++ python/paddle/tensor/fft.py | 274 ++++++++++- 3 files changed, 735 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index e2a1e3306e288..af06653f45968 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -15,11 +15,13 @@ import re import sys import unittest -from math import prod +from scipy.fft import hfftn, hfft2 import numpy as np import paddle +TEST_CASE_NAME = "test_case" + def setUpModule(): global rtol @@ -34,7 +36,7 @@ def tearDownModule(): def rand_x(dims=1, dtype='float32', min_dim_len=1, max_dim_len=10): """generate random input""" - shape = {np.random.randint(min_dim_len, max_dim_len) for i in range(dims)} + shape = [np.random.randint(min_dim_len, max_dim_len) for i in range(dims)] return np.random.randn(*shape).astype(dtype) @@ -69,7 +71,7 @@ def decorator(base_class): def class_name(cls, num, params_dict): suffix = to_safe_name( next((v for v in params_dict.values() if isinstance(v, str)), "")) - if "test_case" in params_dict: + if TEST_CASE_NAME in params_dict: suffix = to_safe_name(params_dict["test_case"]) return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix) @@ -78,20 +80,420 @@ def to_safe_name(s): return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) -@parameterize(('x', 'n', 'axis', 'norm'), [ - (rand_x(1), None, -1, 'backward'), - (rand_x(3, np.float32), None, -1, 'backward'), - (rand_x(3, np.float64), None, -1, 'backward'), +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, -1, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "ortho"), +]) +class TestHfft(unittest.TestCase): + """Test hfft with norm condition + """ + + def test_hfft(self): + np.testing.assert_allclose( + np.fft.hfft(self.x, self.n, self.axis, self.norm), + paddle.tensor.fft.hfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, -1, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "ortho"), +]) +class TestIrfft(unittest.TestCase): + """Test irfft with norm condition + """ + + def test_irfft(self): + np.testing.assert_allclose( + np.fft.irfft(self.x, self.n, self.axis, self.norm), + paddle.tensor.fft.irfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, None, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "ortho"), +]) +class Testirfftn(unittest.TestCase): + """Test irfftn with norm condition + """ + + def test_irfftn(self): + np.testing.assert_allclose( + np.fft.irfftn(self.x, self.n, self.axis, self.norm), + paddle.tensor.fft.irfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, None, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "ortho"), +]) +class Testhfftn(unittest.TestCase): + """Test hfftn with norm condition + """ + + def test_hfftn(self): + np.testing.assert_allclose( + hfftn(self.x, self.n, self.axis, self.norm), + paddle.tensor.fft.hfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [1, 2], (-2, -1), + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 1], (-2, -1), + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "ortho"), +]) +class Testhfft2(unittest.TestCase): + """Test hfft2 with norm condition + """ + + def test_hfft2(self): + np.testing.assert_allclose( + hfft2(self.x, self.s, self.axis, self.norm), + paddle.tensor.fft.hfft2( + paddle.to_tensor(self.x), self.s, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_n_equal_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "ortho"), ]) -class TestRfft(unittest.TestCase): - def test_rfft(self): - self.assertTrue( - np.allclose( - np.fft.rfft(self.x, self.n, self.axis, self.norm), - paddle.tensor.fft.rfft( - paddle.to_tensor(self.x), self.n, self.axis, self.norm), - rtol=rtol.get(str(self.x.dtype)), - atol=atol.get(str(self.x.dtype)))) +class TestIrfft2(unittest.TestCase): + """Test irfft2 with norm condition + """ + + def test_irfft2(self): + np.testing.assert_allclose( + np.fft.irfft2(self.x, self.s, self.axis, self.norm), + paddle.tensor.fft.irfft2( + paddle.to_tensor(self.x), self.s, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, -1, 'backward', + ValueError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, -1, 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1, + 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1, + 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2, 3), -1, 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, 10, 'backward', ValueError), ( + 'test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), + None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, -1, 'random', ValueError)]) +class TestHfftException(unittest.TestCase): + '''Test hfft with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + ''' + + def test_hfft(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.fft.rfft(self.x, self.n, self.axis, self.norm) + + +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, -1, 'backward', + ValueError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, -1, 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1, + 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1, + 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), -1, 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, 10, 'backward', ValueError), ( + 'test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), + None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, None, 'random', ValueError)]) +class TestIrfftException(unittest.TestCase): + '''Test Irfft with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_irfft(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.fft.irfft(self.x, self.n, self.axis, self.norm) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + ValueError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 3, + None, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (1, 2), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + -1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError) +]) +class TestHfft2Exception(unittest.TestCase): + '''Test hfft2 with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - the dimensions of n and axis are different + - norm out of range + ''' + + def test_hfft2(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.fft.hfft2(self.x, self.n, self.axis, self.norm) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + ValueError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 3, + -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (1, 2), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError) +]) +class TestIrfft2Exception(unittest.TestCase): + '''Test irfft2 with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_irfft2(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.fft.irfft2(self.x, self.n, self.axis, self.norm) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + ValueError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 3, + -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (10, 20), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError) +]) +class TestHfftnException(unittest.TestCase): + '''Test hfftn with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_hfftn(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.fft.hfftn(self.x, self.n, self.axis, self.norm) + + +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + ValueError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 3, + -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (10, 20), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError) +]) +class TestIrfftnException(unittest.TestCase): + '''Test irfftn with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_irfftn(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.fft.irfftn(self.x, self.n, self.axis, self.norm) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py b/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py index b9a7651e44909..9c3d3ea596cd8 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py +++ b/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py @@ -11,3 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys.path +from ..op_test import OpTest +from paddle.fluid import Program, program_guard +import paddle.fluid.dygraph as dg +import paddle.static as static +from numpy.random import random as rand + +paddle.enable_static() + + +class TestFFTC2ROp(OpTest): + def setUp(self): + self.op_type = "fft_c2r" + self.init_dtype_type() + self.init_input_output() + self.init_grad_input_output() + + def init_dtype_type(self): + self.dtype = np.complex64 + + def init_input_output(self): + x = (np.random.random((12, 14)) + 1j * np.random.random( + (12, 14))).astype(self.dtype) + out = np.conj(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def init_grad_input_output(self): + self.grad_out = (np.ones((12, 14)) + 1j * np.ones( + (12, 14))).astype(self.dtype) + self.grad_in = np.conj(self.grad_out) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[self.grad_in], + user_defined_grad_outputs=[self.grad_out]) diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py index 1b3691d7fd647..44d5c39f2dfd5 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/tensor/fft.py @@ -147,10 +147,99 @@ def rfft(x, n=None, axis=-1, norm="backward", name=None): def irfft(x, n=None, axis=-1, norm="backward", name=None): + """ + Computes the inverse of `rfft`. + + This function calculates the inverse of the one-dimensional *n* point discrete + Fourier transform of the actual input calculated by "rfft". In other words, + ``irfft(rfft(a),len(a)) == a`` is within the numerical accuracy range. + + The input shall be in the form of "rfft", i.e. the actual zero frequency term, + followed by the complex positive frequency term, in the order of increasing frequency. + Because the discrete Fourier transform of the actual input is Hermite symmetric, + the negative frequency term is regarded as the complex conjugate term of the corresponding + positive frequency term. + + Args: + x (Tensor): The input data. It's a Tensor type. Data type: float32, float64. + n (int, optional): The length of the output transform axis. For `n` output + points, ``n//2 + 1``input points are necessary. If the length of the input tensor is greater + than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, + it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified + along the ` axis'. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncated or zero fill input for the transformation along the axis indicated by + `axis`, or the last input if `axis` is not specified. The length of the conversion axis + is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. + If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` + in some cases. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, 2, 3, 4, 3, 2]) + xp = paddle.to_tensor(x) + irfft_xp = paddle.tensor.fft.irfft(xp).numpy() + print(irfft_xp) + # [15.+0.j, -4.+0.j, 0.+0.j, -1.-0.j, 0.+0.j, -4.+0.j] + + """ return fft_c2r(x, n, axis, norm, forward=False, name=name) def hfft(x, n=None, axis=-1, norm="backward", name=None): + """ + Compute the FFT of a signal that has Hermitian symmetry, a real + spectrum. + + Args: + x (Tensor): The input data. It's a Tensor type. + n (int, optional): The length of the output transform axis. For `n` output + points, ``n//2 + 1`` input points are necessary. If the length of the input tensor is greater + than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, + it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified + along the ` axis'. + axis (int,optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncated or zero fill input for the transformation along the axis indicated by + `axis`, or the last input if `axis` is not specified. The length of the conversion axis + is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. + If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` in + some cases. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, 2, 3, 4, 3, 2]) + xp = paddle.to_tensor(x) + hfft_xp = paddle.tensor.fft.hfft(xp).numpy() + print(hfft_xp) + # [15.+0.j, -4.+0.j, 0.+0.j, -1.-0.j, 0.+0.j, -4.+0.j] + """ + return fft_c2r(x, n, axis, norm, forward=True, name=name) @@ -180,10 +269,120 @@ def rfftn(x, s=None, axes=None, norm="backward", name=None): def irfftn(x, s=None, axes=None, norm="backward", name=None): + """ + Computes the inverse of `rfftn`. + + This function computes the inverse of the N-D discrete + Fourier Transform for real input over any number of axes in an + M-D array by means of the Fast Fourier Transform (FFT). In + other words, ``irfftn(rfftn(x), x.shape) == x`` to within numerical + accuracy. (The ``a.shape`` is necessary like ``len(a)`` is for `irfft`, + and for the same reason.) + + The input should be ordered in the same way as is returned by `rfftn`, + i.e., as for `irfft` for the final transformation axis, and as for `ifftn` + along all the other axes. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): The length of the output transform axis. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. Along any axis, if + the shape indicated by `s` is smaller than that of the input, the input + is cropped. If it is larger, the input is padded with zeros. + If `s` is not given, the shape of the input along the axes specified by axes + is used. Except for the last axis which is taken to be ``2*(k-1)`` where + ``k`` is the length of the input along that axis. + axis (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last + `len(s)` axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or by a combination of `s` or `x`, as explained in the parameters section above. The length of + each transformed axis is as given by the corresponding element of `s`, or the length of the input + in every axis except for the last one if `s` is not given. In the final transformed axis the length + of the output when `s` is not given is ``2*(m-1)``, where ``m`` is the length of the final + transformed axis of the input. To get an odd number of output points in the final axis, + `s` must be specified. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.zeros((3, 2, 2)) + x[0, 0, 0] = 3 * 2 * 2 + xp = paddle.to_tensor(x) + irfftn_xp = paddle.tensor.fft.irfftn(xp).numpy() + print(irfftn_xp) + # [[[1., 1.], + # [1., 1.]], + # [[1., 1.], + # [1., 1.]], + # [[1., 1.], + # [1., 1.]]] + + """ return fftn_c2r(x, s, axes, norm, forward=False, name=name) def hfftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D FFT of Hermitian symmetric complex input, i.e., a + signal with a real spectrum. + + This function calculates the n-D discrete Fourier transform of Hermite symmetric + complex input on any axis in M-D array by fast Fourier transform (FFT). + In other words, ``ihfftn(hfftn(x, s)) == x is within the numerical accuracy range. + (``s`` here are ``x.shape`` and ``s[-1] = x.shape[- 1] * 2 - 1``. This is necessary + for the same reason that ``irfft` requires ``x.shape``.) + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): The length of the output transform axis. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. Along any axis, if + the shape indicated by `s` is smaller than that of the input, the input + is cropped. If it is larger, the input is padded with zeros. + If `s` is not given, the shape of the input along the axes specified by axes + is used. Except for the last axis which is taken to be ``2*(k-1)`` where + ``k`` is the length of the input along that axis. + axis (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last + `len(s)` axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncate or zero fill input, transforming along the axis indicated by axis or + a combination of `s` or `X`. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, 2, 3, 4, 3, 2]) + xp = paddle.to_tensor(x) + hfftn_xp = paddle.tensor.fft.hfftn(xp).numpy() + print(hfftn_xp) + # [15.+0.j, -4.+0.j, 0.+0.j, -1.-0.j, 0.+0.j, -4.+0.j] + + + """ return fftn_c2r(x, s, axes, norm, forward=True, name=name) @@ -200,7 +399,7 @@ def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None): "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". format(s)) if axes is not None: - if not isinstance(axes, Sequence) or len(s) != 2: + if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". format(axes)) @@ -215,7 +414,7 @@ def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None): "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". format(s)) if axes is not None: - if not isinstance(axes, Sequence) or len(s) != 2: + if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". format(axes)) @@ -230,7 +429,7 @@ def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". format(s)) if axes is not None: - if not isinstance(axes, Sequence) or len(s) != 2: + if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". format(axes)) @@ -238,6 +437,37 @@ def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Computes the inverse of `rfft2`. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape of the real output to the inverse FFT. + axis (sequence of ints, optional): The axes over which to compute the inverse FFT. If not specified, + the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. The result of the inverse real 2-D FFT. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, 2, 3, 4, 3, 2]) + xp = paddle.to_tensor(x) + irfft2_xp = paddle.tensor.fft.irfft2(xp).numpy() + print(irfft2_xp) + # [15.+0.j, -4.+0.j, 0.+0.j, -1.-0.j, 0.+0.j, -4.+0.j] + + """ _check_at_least_ndim(x, 2) if s is not None: if not isinstance(s, Sequence) or len(s) != 2: @@ -245,7 +475,7 @@ def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". format(s)) if axes is not None: - if not isinstance(axes, Sequence) or len(s) != 2: + if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". format(axes)) @@ -253,6 +483,38 @@ def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D FFT of a Hermitian complex array. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape of the real output. + axis (sequence of ints, optional): Axes over which to compute the FFT. If not specified, + the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. The real result of the 2-D Hermitian complex real FFT. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, 2, 3, 4, 3, 2]) + xp = paddle.to_tensor(x) + hfft2_xp = paddle.tensor.fft.hfft2(xp).numpy() + print(hfft2_xp) + # [15.+0.j, -4.+0.j, 0.+0.j, -1.-0.j, 0.+0.j, -4.+0.j] + + + """ _check_at_least_ndim(x, 2) if s is not None: if not isinstance(s, Sequence) or len(s) != 2: @@ -260,7 +522,7 @@ def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". format(s)) if axes is not None: - if not isinstance(axes, Sequence) or len(s) != 2: + if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". format(axes)) @@ -275,7 +537,7 @@ def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". format(s)) if axes is not None: - if not isinstance(axes, Sequence) or len(s) != 2: + if not isinstance(axes, Sequence) or len(axes) != 2: raise ValueError( "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". format(axes))