Skip to content

Commit

Permalink
Updating the scipy.fftpack interface to work properly in the rfft and…
Browse files Browse the repository at this point in the history
… irfft cases along with a complete test suite for scipy_fftpack.py to be confident there aren't any other lingering issues.
  • Loading branch information
hgomersall committed Oct 30, 2013
1 parent 9ba7211 commit 2de5e2a
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 54 deletions.
16 changes: 11 additions & 5 deletions pyfftw/interfaces/numpy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,23 @@
of :mod:`numpy.fft`, but those functions that are not included here are imported
directly from :mod:`numpy.fft`.
The precision of the transform that is used is selected from the array that
is passed in, defaulting to double precision if any type conversion is
required.
It is notable that unlike :mod:`numpy.fftpack`, these functions will
generally return an output array with the same precision as the input
array, and the transform that is chosen is chosen based on the precision
of the input array. That is, if the input array is 32-bit floating point,
then the transform will be 32-bit floating point and so will the returned
array. If any type conversion is required, the default will be double
precision.
One known caveat is that repeated axes are handled differently to
:mod:`numpy.fft`; axes that are repeated in the axes argument are considered
only once, as compared to :mod:`numpy.fft` in which repeated axes results in
the DFT being taken along that axes as many times as the axis occurs.
The exceptions raised by each of these functions are as per their
equivalents in :mod:`numpy.fft`.
The exceptions raised by each of these functions are mostly as per their
equivalents in :mod:`numpy.fft`, though there are some corner cases in
which this may not be true.
'''

from ._utils import _Xfftn
Expand Down
137 changes: 132 additions & 5 deletions pyfftw/interfaces/scipy_fftpack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
#
# Copyright 2012 Knowledge Economy Developments Ltd
# Copyright 2013 Knowledge Economy Developments Ltd
#
# Henry Gomersall
# heng@kedevelopments.co.uk
Expand All @@ -23,9 +23,27 @@
:mod:`scipy.fftpack` module. This module *provides* the entire documented
namespace of :mod:`scipy.fftpack`, but those functions that are not included
here are imported directly from :mod:`scipy.fftpack`.
The exceptions raised by each of these functions are mostly as per their
equivalents in :mod:`scipy.fftpack`, though there are some corner cases in
which this may not be true.
It is notable that unlike :mod:`scipy.fftpack`, these functions will
generally return an output array with the same precision as the input
array, and the transform that is chosen is chosen based on the precision
of the input array. That is, if the input array is 32-bit floating point,
then the transform will be 32-bit floating point and so will the returned
array. If any type conversion is required, the default will be double
precision.
Some corner (mis)usages of :mod:`scipy.fftpack` may not transfer neatly.
For example, using :func:`scipy.fftpack.fft2` with a non 1D array and
a 2D `shape` argument will return without exception whereas
:func:`pyfftw.interfaces.scipy_fftpack.fft2` will raise a `ValueError`.
'''

from . import numpy_fft
import numpy

# Complete the namespace (these are not actually used in this module)
from scipy.fftpack import (dct, idct, diff, tilbert, itilbert,
Expand All @@ -46,7 +64,6 @@ def fft(x, n=None, axis=-1, overwrite_x=False,
the rest of the arguments are documented
in the :ref:`additional argument docs<interfaces_additional_args>`.
'''

return numpy_fft.fft(x, n, axis, overwrite_x, planner_effort,
threads, auto_align_input, auto_contiguous)

Expand Down Expand Up @@ -102,6 +119,16 @@ def fftn(x, shape=None, axes=None, overwrite_x=False,
in the :ref:`additional argument docs<interfaces_additional_args>`.
'''

if shape is not None:
if ((axes is not None and len(shape) != len(axes)) or
(axes is None and len(shape) != x.ndim)):
raise ValueError('Shape error: In order to maintain better '
'compatibility with scipy.fftpack.fftn, a ValueError '
'is raised when the length of the shape argument is '
'not the same as x.ndim if axes is None or the length '
'of axes if it is not. If this is problematic, consider '
'using the numpy interface.')

return numpy_fft.fftn(x, shape, axes, overwrite_x, planner_effort,
threads, auto_align_input, auto_contiguous)

Expand All @@ -116,9 +143,88 @@ def ifftn(x, shape=None, axes=None, overwrite_x=False,
in the :ref:`additional argument docs<interfaces_additional_args>`.
'''

if shape is not None:
if ((axes is not None and len(shape) != len(axes)) or
(axes is None and len(shape) != x.ndim)):
raise ValueError('Shape error: In order to maintain better '
'compatibility with scipy.fftpack.ifftn, a ValueError '
'is raised when the length of the shape argument is '
'not the same as x.ndim if axes is None or the length '
'of axes if it is not. If this is problematic, consider '
'using the numpy interface.')

return numpy_fft.ifftn(x, shape, axes, overwrite_x, planner_effort,
threads, auto_align_input, auto_contiguous)

def _complex_to_rfft_output(complex_output, output_shape, axis):
'''Convert the complex output from pyfftw to the real output expected
from :func:`scipy.fftpack.rfft`.
'''

rfft_output = numpy.empty(output_shape, dtype=complex_output.real.dtype)
source_slicer = [slice(None)] * complex_output.ndim
target_slicer = [slice(None)] * complex_output.ndim

# First element
source_slicer[axis] = slice(0, 1)
target_slicer[axis] = slice(0, 1)
rfft_output[target_slicer] = complex_output[source_slicer].real

# Real part
source_slicer[axis] = slice(1, None)
target_slicer[axis] = slice(1, None, 2)
rfft_output[target_slicer] = complex_output[source_slicer].real

# Imaginary part
if output_shape[axis] % 2 == 0:
end_val = -1
else:
end_val = None

source_slicer[axis] = slice(1, end_val, None)
target_slicer[axis] = slice(2, None, 2)
rfft_output[target_slicer] = complex_output[source_slicer].imag

return rfft_output


def _irfft_input_to_complex(irfft_input, axis):
'''Convert the expected real input to :func:`scipy.fftpack.irfft` to
the complex input needed by pyfftw.
'''
complex_dtype = numpy.result_type(irfft_input, 1j)

input_shape = list(irfft_input.shape)
input_shape[axis] = input_shape[axis]//2 + 1

complex_input = numpy.empty(input_shape, dtype=complex_dtype)
source_slicer = [slice(None)] * len(input_shape)
target_slicer = [slice(None)] * len(input_shape)

# First element
source_slicer[axis] = slice(0, 1)
target_slicer[axis] = slice(0, 1)
complex_input[target_slicer] = irfft_input[source_slicer]

# Real part
source_slicer[axis] = slice(1, None, 2)
target_slicer[axis] = slice(1, None)
complex_input[target_slicer].real = irfft_input[source_slicer]

# Imaginary part
if irfft_input.shape[axis] % 2 == 0:
end_val = -1
target_slicer[axis] = slice(-1, None)
complex_input[target_slicer].imag = 0.0
else:
end_val = None

source_slicer[axis] = slice(2, None, 2)
target_slicer[axis] = slice(1, end_val)
complex_input[target_slicer].imag = irfft_input[source_slicer]

return complex_input


def rfft(x, n=None, axis=-1, overwrite_x=False,
planner_effort='FFTW_MEASURE', threads=1,
Expand All @@ -129,10 +235,21 @@ def rfft(x, n=None, axis=-1, overwrite_x=False,
the rest of the arguments are documented
in the :ref:`additional argument docs<interfaces_additional_args>`.
'''
if not numpy.isrealobj(x):
raise TypeError('Input array must be real to maintain '
'compatibility with scipy.fftpack.rfft.')

return numpy_fft.rfft(x, n, axis, overwrite_x, planner_effort,
x = numpy.asanyarray(x)

complex_output = numpy_fft.rfft(x, n, axis, overwrite_x, planner_effort,
threads, auto_align_input, auto_contiguous)

output_shape = list(x.shape)
if n is not None:
output_shape[axis] = n

return _complex_to_rfft_output(complex_output, output_shape, axis)

def irfft(x, n=None, axis=-1, overwrite_x=False,
planner_effort='FFTW_MEASURE', threads=1,
auto_align_input=True, auto_contiguous=True):
Expand All @@ -142,7 +259,17 @@ def irfft(x, n=None, axis=-1, overwrite_x=False,
the rest of the arguments are documented
in the :ref:`additional argument docs<interfaces_additional_args>`.
'''
if not numpy.isrealobj(x):
raise TypeError('Input array must be real to maintain '
'compatibility with scipy.fftpack.irfft.')

return numpy_fft.irfft(x, n, axis, overwrite_x, planner_effort,
threads, auto_align_input, auto_contiguous)
x = numpy.asanyarray(x)

if n is None:
n = x.shape[axis]

complex_input = _irfft_input_to_complex(x, axis)

return numpy_fft.irfft(complex_input, n, axis, overwrite_x,
planner_effort, threads, auto_align_input, auto_contiguous)

Loading

0 comments on commit 2de5e2a

Please sign in to comment.