Skip to content

Commit

Permalink
implement fft2, ifft2, fftn, ifftn
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Aug 2, 2024
1 parent 0c3dfe5 commit f7ed376
Show file tree
Hide file tree
Showing 9 changed files with 889 additions and 234 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/core_tests
third_party/cupy/fft_tests
third_party/cupy/creation_tests
third_party/cupy/indexing_tests/test_indexing.py
third_party/cupy/lib_tests
Expand Down
487 changes: 363 additions & 124 deletions dpnp/fft/dpnp_iface_fft.py

Large diffs are not rendered by default.

189 changes: 187 additions & 2 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dpu
import numpy
from dpctl.tensor._numpy_helper import normalize_axis_index
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
)
from dpctl.utils import ExecutionPlacementError

import dpnp
Expand All @@ -66,6 +69,63 @@ def _check_norm(norm):
)


def _cook_nd_args(a, s=None, axes=None, c2r=False):
if axes is not None:
# validate axes is a sequence and
# each axis is an integer within the range
normalize_axis_tuple(list(set(axes)), a.ndim, "axes")

if s is None:
shapeless = True
if axes is None:
s = list(a.shape)
else:
s = numpy.take(a.shape, axes)
else:
shapeless = False
try:
iter(s)
except Exception as e:
raise TypeError(
"`s` must be `None` or a sequence of integers."
) from e

for s_i in s:
if s_i is not None and not isinstance(s_i, int):
raise TypeError("`s` must be `None` or a sequence of integers.")

for s_i in s:
if s_i is not None and s_i < 1 and s_i != -1:
raise ValueError(
f"Invalid number of FFT data points ({s_i}) specified."
)

if axes is None:
# TODO: uncomment the checkpoint
# both `s` and `axes` being `None` is currently deprecated
# and will raise an error in future versions of NumPy
# if not shapeless:
# raise ValueError(
# "`axes` should not be `None` if `s` is not `None`."
# )
axes = list(range(-len(s), 0))
if len(s) != len(axes):
raise ValueError("Shape and axes have different lengths.")

# TODO: remove this for loop
# support of `i`` being `None`` is deprecated and will raise
# a TypeError in future versions of NumPy
for i, s_i in enumerate(s):
s = list(s)
s[i] = a.shape[axes[i]] if s_i is None else s_i

if c2r and shapeless:
s[-1] = (a.shape[axes[-1]] - 1) * 2
# use the whole input array along axis `i` if `s[i] == -1`
s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
return s, axes


def _commit_descriptor(a, in_place, c2c, a_strides, index, axes):
"""Commit the FFT descriptor for the input array."""

Expand Down Expand Up @@ -205,6 +265,63 @@ def _copy_array(x, complex_input):
return x, copy_flag


def _extract_axes_chunk(a, chunk_size=3):
"""
Classify input into a list of list with each list containing
only unique values and its length is at most `chunk_size`.
Parameters
----------
a : list, tuple
Input.
chunk_size : int
Maximum number of elements in each chunk.
Return
------
out : list of lists
List of lists with each list containing only unique values
and its length is at most `chunk_size`.
The final list is returned in reverse order.
Examples
--------
>>> axes = (0, 1, 2, 3, 4)
>>> _extract_axes_chunk(axes, chunk_size=3)
[[2, 3, 4], [0, 1]]
>>> axes = (0, 1, 2, 3, 4, 4)
>>> _extract_axes_chunk(axes, chunk_size=3)
[[4], [2, 3, 4], [0, 1]]
"""

chunks = []
current_chunk = []
seen_elements = set()

for elem in a:
if elem in seen_elements:
# If element is already seen, start a new chunk
chunks.append(current_chunk)
current_chunk = [elem]
seen_elements = {elem}
else:
current_chunk.append(elem)
seen_elements.add(elem)

if len(current_chunk) == chunk_size:
chunks.append(current_chunk)
current_chunk = []
seen_elements = set()

# Add the last chunk if it's not empty
if current_chunk:
chunks.append(current_chunk)

return chunks[::-1]


def _fft(a, norm, out, forward, in_place, c2c, axes=None):
"""Calculates FFT of the input array along the specified axes."""

Expand Down Expand Up @@ -239,7 +356,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):

def _scale_result(res, a_shape, norm, forward, index):
"""Scale the result of the FFT according to `norm`."""
scale = numpy.prod(a_shape[index:], dtype=res.real.dtype)
if res.dtype in [dpnp.float32, dpnp.complex64]:
dtype = dpnp.float32
else:
dtype = dpnp.float64
scale = numpy.prod(a_shape[index:], dtype=dtype)
norm_factor = 1
if norm == "ortho":
norm_factor = numpy.sqrt(scale)
Expand Down Expand Up @@ -332,6 +453,7 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
"""Calculates 1-D FFT of the input array along axis"""

_check_norm(norm)
a_ndim = a.ndim
if a_ndim == 0:
raise ValueError("Input array must be at least 1D")
Expand Down Expand Up @@ -378,3 +500,66 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
c2c=c2c,
axes=axis,
)


def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):
"""Calculates N-D FFT of the input array along axes"""

_check_norm(norm)
if isinstance(axes, (list, tuple)) and len(axes) == 0:
return a

if a.ndim == 0:
if axes is not None:
raise IndexError(
"Input array is 0-dimensional while axis is not `None`."
)

return a

s, axes = _cook_nd_args(a, s, axes)
a = _truncate_or_pad(a, s, axes)
# TODO: None, False, False are place holder for future development of
# rfft2, irfft2, rfftn, irfftn
_validate_out_keyword(a, out, None, False, False)
# TODO: True is place holder for future development of
# rfft2, irfft2, rfftn, irfftn
a, in_place = _copy_array(a, True)

if a.size == 0:
return dpnp.get_result_array(a, out=out, casting="same_kind")

len_axes = len(axes)
# OneMKL supports up to 3-dimensional FFT on GPU
# repeated axis in OneMKL FFT is not allowed
if len_axes > 3 or len(set(axes)) < len_axes:
axes_chunk = _extract_axes_chunk(axes, chunk_size=3)
for chunk in axes_chunk:
a = _fft(
a,
norm=norm,
out=out,
forward=forward,
in_place=in_place,
# TODO: c2c=True is place holder for future development of
# rfft2, irfft2, rfftn, irfftn
c2c=True,
axes=chunk,
)
return a

if a.ndim == len_axes:
# non-batch FFT
axes = None

return _fft(
a,
norm=norm,
out=out,
forward=forward,
in_place=in_place,
# TODO: c2c=True is place holder for future development of
# rfft2, irfft2, rfftn, irfftn
c2c=True,
axes=axes,
)
24 changes: 0 additions & 24 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]

tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2

tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn

tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory

tests/test_umath.py::test_umaths[('divmod', 'ii')]
Expand Down
24 changes: 0 additions & 24 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -111,30 +111,6 @@ tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_pa
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_max
tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestCubReduction_param_7_{order='F', shape=(10, 20, 30, 40)}::test_cub_min

tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_9_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_15_{axes=(), norm=None, s=None, shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_16_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_18_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_19_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fft2
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_20_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fft2

tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_17_{axes=(), norm='ortho', s=None, shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_18_{axes=(0, 1, 2), norm='ortho', s=(2, 3), shape=(2, 3, 4)}::test_ifftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_10_{axes=None, norm=None, s=(1, 4, None), shape=(2, 3, 4)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_21_{axes=None, norm=None, s=None, shape=(0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_22_{axes=None, norm=None, s=None, shape=(2, 0, 5)}::test_fftn
tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm=None, s=None, shape=(0, 0, 5)}::test_fftn

tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1
Expand Down
Loading

0 comments on commit f7ed376

Please sign in to comment.