Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of dpnp.fft.fft2, dpnp.fft.ifft2, dpnp.fft.fftn, dpnp.fft.ifftn #1961

Merged
merged 17 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/core_tests
third_party/cupy/fft_tests
third_party/cupy/creation_tests
third_party/cupy/indexing_tests/test_indexing.py
third_party/cupy/lib_tests
Expand Down
518 changes: 381 additions & 137 deletions dpnp/fft/dpnp_iface_fft.py

Large diffs are not rendered by default.

231 changes: 220 additions & 11 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@
# pylint: disable=protected-access
# pylint: disable=no-name-in-module

from collections.abc import Sequence

import dpctl
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dpu
import numpy
from dpctl.tensor._numpy_helper import normalize_axis_index
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
)
from dpctl.utils import ExecutionPlacementError

vtavana marked this conversation as resolved.
Show resolved Hide resolved
import dpnp
Expand All @@ -54,6 +59,7 @@

__all__ = [
"dpnp_fft",
"dpnp_fftn",
]


Expand Down Expand Up @@ -159,6 +165,37 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides):
return result


# TODO: c2r keyword is place holder for irfftn
def _cook_nd_args(a, s=None, axes=None, c2r=False):
if s is None:
shapeless = True
if axes is None:
s = list(a.shape)
else:
s = numpy.take(a.shape, axes)
else:
shapeless = False

for s_i in s:
if s_i is not None and s_i < 1 and s_i != -1:
raise ValueError(
f"Invalid number of FFT data points ({s_i}) specified."
)

if axes is None:
axes = list(range(-len(s), 0))

if len(s) != len(axes):
raise ValueError("Shape and axes have different lengths.")

s = list(s)
if c2r and shapeless:
s[-1] = (a.shape[axes[-1]] - 1) * 2
# use the whole input array along axis `i` if `s[i] == -1`
s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
return s, axes


def _copy_array(x, complex_input):
"""
Creating a C-contiguous copy of input array if input array has a negative
Expand Down Expand Up @@ -204,6 +241,80 @@ def _copy_array(x, complex_input):
return x, copy_flag


def _extract_axes_chunk(a, s, chunk_size=3):
"""
Classify the first input into a list of lists with each list containing
only unique values in reverse order and its length is at most `chunk_size`.
The second input is also classified into a list of lists with each list
containing the corresponding values of the first input.

Parameters
----------
a : list or tuple of ints
The first input.
s : list or tuple of ints
The second input.
chunk_size : int
Maximum number of elements in each chunk.

Return
------
out : a tuple of two lists
The first element of output is a list of lists with each list
containing only unique values in revere order and its length is
at most `chunk_size`.
The second element of output is a list of lists with each list
containing the corresponding values of the first input.

Examples
--------
>>> axes = (0, 1, 2, 3, 4)
>>> shape = (7, 8, 10, 9, 5)
>>> _extract_axes_chunk(axes, shape, chunk_size=3)
([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])

>>> axes = (1, 0, 3, 2, 4, 4)
>>> shape = (7, 8, 10, 5, 7, 6)
>>> _extract_axes_chunk(axes, shape, chunk_size=3)
([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])

"""

a_chunks = []
a_current_chunk = []
seen_elements = set()

s_chunks = []
s_current_chunk = []

for a_elem, s_elem in zip(a, s):
if a_elem in seen_elements:
# If element is already seen, start a new chunk
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])
a_current_chunk = [a_elem]
s_current_chunk = [s_elem]
seen_elements = {a_elem}
else:
a_current_chunk.append(a_elem)
s_current_chunk.append(s_elem)
seen_elements.add(a_elem)

if len(a_current_chunk) == chunk_size:
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])
a_current_chunk = []
s_current_chunk = []
seen_elements = set()

# Add the last chunk if it's not empty
if a_current_chunk:
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])

return a_chunks[::-1], s_chunks[::-1]


def _fft(a, norm, out, forward, in_place, c2c, axes=None):
"""Calculates FFT of the input array along the specified axes."""

Expand Down Expand Up @@ -238,7 +349,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):

def _scale_result(res, a_shape, norm, forward, index):
"""Scale the result of the FFT according to `norm`."""
scale = numpy.prod(a_shape[index:], dtype=res.real.dtype)
if res.dtype in [dpnp.float32, dpnp.complex64]:
dtype = dpnp.float32
else:
dtype = dpnp.float64
scale = numpy.prod(a_shape[index:], dtype=dtype)
norm_factor = 1
if norm == "ortho":
norm_factor = numpy.sqrt(scale)
Expand Down Expand Up @@ -293,7 +408,7 @@ def _truncate_or_pad(a, shape, axes):
return a


def _validate_out_keyword(a, out, axis, c2r, r2c):
def _validate_out_keyword(a, out, s, axes, c2r, r2c):
"""Validate out keyword argument."""
if out is not None:
dpnp.check_supported_arrays_type(out)
Expand All @@ -305,16 +420,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
"Input and output allocation queues are not compatible"
)

# validate out shape
expected_shape = a.shape
# validate out shape against the final shape,
# intermediate shapes may vary
expected_shape = list(a.shape)
for s_i, axis in zip(s[::-1], axes[::-1]):
expected_shape[axis] = s_i
if r2c:
expected_shape = list(a.shape)
expected_shape[axis] = a.shape[axis] // 2 + 1
expected_shape = tuple(expected_shape)
if out.shape != expected_shape:
expected_shape[axes[-1]] = expected_shape[axes[-1]] // 2 + 1

if out.shape != tuple(expected_shape):
raise ValueError(
"output array has incorrect shape, expected "
f"{expected_shape}, got {out.shape}."
f"{tuple(expected_shape)}, got {out.shape}."
)

# validate out data type
Expand All @@ -328,9 +445,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
raise TypeError("output array should have complex data type.")


def _validate_s_axes(a, s, axes):
if axes is not None:
# validate axes is a sequence and
# each axis is an integer within the range
normalize_axis_tuple(list(set(axes)), a.ndim, "axes")

if s is not None:
raise_error = False
if isinstance(s, Sequence):
if any(not isinstance(s_i, int) for s_i in s):
raise_error = True
else:
raise_error = True

if raise_error:
raise TypeError("`s` must be `None` or a sequence of integers.")

if axes is None:
raise ValueError(
"`axes` should not be `None` if `s` is not `None`."
)


def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
"""Calculates 1-D FFT of the input array along axis"""

_check_norm(norm)
a_ndim = a.ndim
if a_ndim == 0:
raise ValueError("Input array must be at least 1D")
Expand All @@ -354,7 +495,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):

_check_norm(norm)
a = _truncate_or_pad(a, n, axis)
_validate_out_keyword(a, out, axis, c2r, r2c)
_validate_out_keyword(a, out, (n,), (axis,), c2r, r2c)
# if input array is copied, in-place FFT can be used
a, in_place = _copy_array(a, c2c or c2r)
if not in_place and out is not None:
Expand All @@ -377,3 +518,71 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
c2c=c2c,
axes=axis,
)


def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
"""Calculates N-D FFT of the input array along axes"""

_check_norm(norm)
if isinstance(axes, (list, tuple)) and len(axes) == 0:
return a

if a.ndim == 0:
if axes is not None:
raise IndexError(
"Input array is 0-dimensional while axis is not `None`."
)

return a

_validate_s_axes(a, s, axes)
s, axes = _cook_nd_args(a, s, axes)
# TODO: False and False are place holder for future development of
# rfft2, irfft2, rfftn, irfftn
_validate_out_keyword(a, out, s, axes, False, False)
# TODO: True is place holder for future development of
# rfft2, irfft2, rfftn, irfftn
a, in_place = _copy_array(a, True)

len_axes = len(axes)
# OneMKL supports up to 3-dimensional FFT on GPU
# repeated axis in OneMKL FFT is not allowed
if len_axes > 3 or len(set(axes)) < len_axes:
axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
for s_chunk, a_chunk in zip(shape_chunk, axes_chunk):
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
if out is not None and out.shape == a.shape:
tmp_out = out
else:
tmp_out = None
a = _fft(
a,
norm=norm,
out=tmp_out,
forward=forward,
in_place=in_place,
# TODO: c2c=True is place holder for future development of
# rfft2, irfft2, rfftn, irfftn
c2c=True,
axes=a_chunk,
)
return a

a = _truncate_or_pad(a, s, axes)
if a.size == 0:
return dpnp.get_result_array(a, out=out, casting="same_kind")
if a.ndim == len_axes:
# non-batch FFT
axes = None

return _fft(
a,
norm=norm,
out=out,
forward=forward,
in_place=in_place,
# TODO: c2c=True is place holder for future development of
# rfft2, irfft2, rfftn, irfftn
c2c=True,
axes=axes,
)
24 changes: 0 additions & 24 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]

tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2

tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn

tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory

tests/test_umath.py::test_umaths[('divmod', 'ii')]
Expand Down
24 changes: 0 additions & 24 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -110,30 +110,6 @@ tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_pa
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_max
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_min

tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2

tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn

tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1
Expand Down
Loading
Loading