Skip to content

Commit

Permalink
Implementation of dpnp.fft.fft2, dpnp.fft.ifft2, dpnp.fft.fftn,…
Browse files Browse the repository at this point in the history
… `dpnp.fft.ifftn` (#1961)

* implement fft2, ifft2, fftn, ifftn

* unmute a few tests

* Revert "unmute a few tests"

This reverts commit de6e0e3.

* update a few tests

* improve coverage + update tests

* address comments

* raise Error for deprecated behaviors in NumPy 2.0

* only support sequence for s and axes

* update a test + keep functions in alphabetic order

* update when both s and axes are given

* revert incorrect change for cupy test
  • Loading branch information
vtavana authored Aug 12, 2024
1 parent 05196fe commit 4b3b324
Show file tree
Hide file tree
Showing 9 changed files with 1,006 additions and 261 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/core_tests
third_party/cupy/fft_tests
third_party/cupy/creation_tests
third_party/cupy/indexing_tests/test_indexing.py
third_party/cupy/lib_tests
Expand Down
518 changes: 381 additions & 137 deletions dpnp/fft/dpnp_iface_fft.py

Large diffs are not rendered by default.

231 changes: 220 additions & 11 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@
# pylint: disable=protected-access
# pylint: disable=no-name-in-module

from collections.abc import Sequence

import dpctl
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dpu
import numpy
from dpctl.tensor._numpy_helper import normalize_axis_index
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
)
from dpctl.utils import ExecutionPlacementError

import dpnp
Expand All @@ -54,6 +59,7 @@

__all__ = [
"dpnp_fft",
"dpnp_fftn",
]


Expand Down Expand Up @@ -159,6 +165,37 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides):
return result


# TODO: the `c2r` keyword is a placeholder for irfftn
def _cook_nd_args(a, s=None, axes=None, c2r=False):
    """
    Resolve the `s` and `axes` arguments of an N-D FFT into concrete values.

    Parameters
    ----------
    a : array
        Input array; only its ``shape`` attribute is read here.
    s : sequence of ints or None
        Requested number of points along each transformed axis. When it is
        ``None`` the sizes are taken from ``a.shape`` (along `axes` if they
        are given, otherwise along every axis). An entry equal to ``-1``
        means "use the full length of `a` along the matching axis".
    axes : sequence of ints or None
        Axes to transform. When it is ``None`` the last ``len(s)`` axes are
        used.
    c2r : bool
        When ``True`` and `s` was not given, the size of the last axis is
        replaced with ``(a.shape[axes[-1]] - 1) * 2``.

    Returns
    -------
    out : tuple
        ``(s, axes)`` where `s` is a list of sizes. `axes` is returned as
        passed in, unless it was ``None``, in which case it is a list of
        negative indices.

    Raises
    ------
    ValueError
        If an entry of `s` is smaller than 1 and is not ``-1``, or if `s`
        and `axes` have different lengths.

    """

    shapeless = s is None
    if shapeless:
        # no explicit sizes: fall back to the extents of the input array
        s = list(a.shape) if axes is None else numpy.take(a.shape, axes)

    for s_i in s:
        # `None` and `-1` entries are let through untouched by this check
        if s_i is None or s_i == -1:
            continue
        if s_i < 1:
            raise ValueError(
                f"Invalid number of FFT data points ({s_i}) specified."
            )

    if axes is None:
        axes = list(range(-len(s), 0))

    if len(s) != len(axes):
        raise ValueError("Shape and axes have different lengths.")

    s = list(s)
    if c2r and shapeless:
        s[-1] = (a.shape[axes[-1]] - 1) * 2

    # use the whole input array along axis `i` if `s[i] == -1`
    resolved = []
    for n_points, axis in zip(s, axes):
        resolved.append(a.shape[axis] if n_points == -1 else n_points)
    return resolved, axes


def _copy_array(x, complex_input):
"""
Creating a C-contiguous copy of input array if input array has a negative
Expand Down Expand Up @@ -204,6 +241,80 @@ def _copy_array(x, complex_input):
return x, copy_flag


def _extract_axes_chunk(a, s, chunk_size=3):
    """
    Split `a` (and `s` in lockstep) into chunks of unique values.

    `a` is walked from left to right and its values are collected into
    chunks. A chunk is closed as soon as it holds `chunk_size` values or the
    next value is already present in it. The values inside every chunk are
    then reversed, and the list of chunks itself is reversed as well. `s` is
    split in exactly the same way, so every value of `s` stays paired with
    the value of `a` it was given with.

    Parameters
    ----------
    a : list or tuple of ints
        The first input.
    s : list or tuple of ints
        The second input.
    chunk_size : int
        Maximum number of elements in each chunk.

    Returns
    -------
    out : a tuple of two lists
        The first element of output is a list of lists with each list
        containing only unique values in reverse order and its length is
        at most `chunk_size`.
        The second element of output is a list of lists with each list
        containing the corresponding values of the second input.

    Examples
    --------
    >>> axes = (0, 1, 2, 3, 4)
    >>> shape = (7, 8, 10, 9, 5)
    >>> _extract_axes_chunk(axes, shape, chunk_size=3)
    ([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])

    >>> axes = (1, 0, 3, 2, 4, 4)
    >>> shape = (7, 8, 10, 5, 7, 6)
    >>> _extract_axes_chunk(axes, shape, chunk_size=3)
    ([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])

    """

    axes_groups, size_groups = [], []
    group_axes, group_sizes = [], []
    in_group = set()

    for axis, size in zip(a, s):
        if axis in in_group:
            # a repeated value cannot share a chunk with its earlier
            # occurrence: close the running chunk before adding it
            axes_groups.append(group_axes[::-1])
            size_groups.append(group_sizes[::-1])
            group_axes, group_sizes = [], []
            in_group = set()

        group_axes.append(axis)
        group_sizes.append(size)
        in_group.add(axis)

        if len(group_axes) == chunk_size:
            # the running chunk is full
            axes_groups.append(group_axes[::-1])
            size_groups.append(group_sizes[::-1])
            group_axes, group_sizes = [], []
            in_group = set()

    # keep whatever is left over as the final chunk
    if group_axes:
        axes_groups.append(group_axes[::-1])
        size_groups.append(group_sizes[::-1])

    axes_groups.reverse()
    size_groups.reverse()
    return axes_groups, size_groups


def _fft(a, norm, out, forward, in_place, c2c, axes=None):
"""Calculates FFT of the input array along the specified axes."""

Expand Down Expand Up @@ -238,7 +349,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):

def _scale_result(res, a_shape, norm, forward, index):
"""Scale the result of the FFT according to `norm`."""
scale = numpy.prod(a_shape[index:], dtype=res.real.dtype)
if res.dtype in [dpnp.float32, dpnp.complex64]:
dtype = dpnp.float32
else:
dtype = dpnp.float64
scale = numpy.prod(a_shape[index:], dtype=dtype)
norm_factor = 1
if norm == "ortho":
norm_factor = numpy.sqrt(scale)
Expand Down Expand Up @@ -293,7 +408,7 @@ def _truncate_or_pad(a, shape, axes):
return a


def _validate_out_keyword(a, out, axis, c2r, r2c):
def _validate_out_keyword(a, out, s, axes, c2r, r2c):
"""Validate out keyword argument."""
if out is not None:
dpnp.check_supported_arrays_type(out)
Expand All @@ -305,16 +420,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
"Input and output allocation queues are not compatible"
)

# validate out shape
expected_shape = a.shape
# validate out shape against the final shape,
# intermediate shapes may vary
expected_shape = list(a.shape)
for s_i, axis in zip(s[::-1], axes[::-1]):
expected_shape[axis] = s_i
if r2c:
expected_shape = list(a.shape)
expected_shape[axis] = a.shape[axis] // 2 + 1
expected_shape = tuple(expected_shape)
if out.shape != expected_shape:
expected_shape[axes[-1]] = expected_shape[axes[-1]] // 2 + 1

if out.shape != tuple(expected_shape):
raise ValueError(
"output array has incorrect shape, expected "
f"{expected_shape}, got {out.shape}."
f"{tuple(expected_shape)}, got {out.shape}."
)

# validate out data type
Expand All @@ -328,9 +445,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
raise TypeError("output array should have complex data type.")


def _validate_s_axes(a, s, axes):
    """
    Validate the `s` and `axes` keyword arguments of an N-D FFT.

    Parameters
    ----------
    a : array
        Input array; only ``a.ndim`` is read, and only when `axes` is given.
    s : sequence of ints or None
        Requested shape along the transformed axes.
    axes : sequence of ints or None
        Axes to transform.

    Raises
    ------
    TypeError
        If `s` is neither ``None`` nor a sequence whose items are all
        instances of ``int``.
    ValueError
        If `s` is given while `axes` is ``None``.

    Any exception raised by ``normalize_axis_tuple`` for `axes` is passed
    through unchanged.

    """

    if axes is not None:
        # validate axes is a sequence and
        # each axis is an integer within the range;
        # duplicates are dropped first, so repeated axes are accepted here
        normalize_axis_tuple(list(set(axes)), a.ndim, "axes")

    if s is None:
        return

    is_int_sequence = isinstance(s, Sequence) and all(
        isinstance(s_i, int) for s_i in s
    )
    if not is_int_sequence:
        raise TypeError("`s` must be `None` or a sequence of integers.")

    if axes is None:
        raise ValueError(
            "`axes` should not be `None` if `s` is not `None`."
        )


def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
"""Calculates 1-D FFT of the input array along axis"""

_check_norm(norm)
a_ndim = a.ndim
if a_ndim == 0:
raise ValueError("Input array must be at least 1D")
Expand All @@ -354,7 +495,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):

_check_norm(norm)
a = _truncate_or_pad(a, n, axis)
_validate_out_keyword(a, out, axis, c2r, r2c)
_validate_out_keyword(a, out, (n,), (axis,), c2r, r2c)
# if input array is copied, in-place FFT can be used
a, in_place = _copy_array(a, c2c or c2r)
if not in_place and out is not None:
Expand All @@ -377,3 +518,71 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
c2c=c2c,
axes=axis,
)


def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
    """
    Calculates N-D FFT of the input array along axes.

    Parameters
    ----------
    a : array
        Input array.
    forward : bool
        Forwarded unchanged to ``_fft`` as its `forward` argument.
    s : sequence of ints or None
        Requested shape along the transformed axes; checked by
        ``_validate_s_axes`` and resolved by ``_cook_nd_args``.
    axes : sequence of ints or None
        Axes to transform. An empty list or tuple returns `a` as is.
    norm : str or None
        Normalization mode; checked by ``_check_norm``.
    out : array or None
        Optional output array; checked by ``_validate_out_keyword``.

    Returns
    -------
    out : array
        Result of the transform. `a` itself is returned when `axes` is an
        empty list/tuple, or when `a` is 0-dimensional and `axes` is
        ``None``.

    Raises
    ------
    IndexError
        If `a` is 0-dimensional and `axes` is not ``None``.

    """

    _check_norm(norm)

    no_axes_requested = isinstance(axes, (list, tuple)) and len(axes) == 0
    if no_axes_requested:
        return a

    if a.ndim == 0:
        if axes is None:
            return a
        raise IndexError(
            "Input array is 0-dimensional while axis is not `None`."
        )

    _validate_s_axes(a, s, axes)
    s, axes = _cook_nd_args(a, s, axes)
    # TODO: False and False are place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    _validate_out_keyword(a, out, s, axes, False, False)
    # TODO: True is place holder for future development of
    # rfft2, irfft2, rfftn, irfftn
    a, in_place = _copy_array(a, True)

    len_axes = len(axes)
    # OneMKL supports up to 3-dimensional FFT on GPU
    # repeated axis in OneMKL FFT is not allowed
    needs_chunking = len_axes > 3 or len(set(axes)) < len_axes

    if not needs_chunking:
        a = _truncate_or_pad(a, s, axes)
        if a.size == 0:
            return dpnp.get_result_array(a, out=out, casting="same_kind")

        return _fft(
            a,
            norm=norm,
            out=out,
            forward=forward,
            in_place=in_place,
            # TODO: c2c=True is place holder for future development of
            # rfft2, irfft2, rfftn, irfftn
            c2c=True,
            # non-batch FFT (every axis is transformed) is requested by
            # passing `axes=None`
            axes=None if a.ndim == len_axes else axes,
        )

    # transform chunk by chunk; each chunk holds at most 3 distinct axes
    axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
    for s_chunk, a_chunk in zip(shape_chunk, axes_chunk):
        a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
        # `out` is validated against the final shape only, so it is handed
        # to a step only when the current shape matches it
        shapes_match = out is not None and out.shape == a.shape
        a = _fft(
            a,
            norm=norm,
            out=out if shapes_match else None,
            forward=forward,
            in_place=in_place,
            # TODO: c2c=True is place holder for future development of
            # rfft2, irfft2, rfftn, irfftn
            c2c=True,
            axes=a_chunk,
        )
    return a
24 changes: 0 additions & 24 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]

tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2

tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn

tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory

tests/test_umath.py::test_umaths[('divmod', 'ii')]
Expand Down
24 changes: 0 additions & 24 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -110,30 +110,6 @@ tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_pa
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_max
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_min

tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2

tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn

tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1
Expand Down
Loading

0 comments on commit 4b3b324

Please sign in to comment.