Skip to content

Commit

Permalink
TST: update FFTW interface tests to test the norm argument as well
Browse files Browse the repository at this point in the history
  • Loading branch information
grlee77 committed May 19, 2016
1 parent 321e8c5 commit 264839b
Showing 1 changed file with 41 additions and 6 deletions.
47 changes: 41 additions & 6 deletions test/test_pyfftw_numpy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ def make_complex_data(shape, dtype):
def make_real_data(shape, dtype):
return dtype(numpy.random.randn(*shape))

def _numpy_fft_has_norm_kwarg():
"""returns True if numpy's fft supports the norm keyword argument
This should be true for numpy >= 1.10
"""
# return LooseVersion(numpy.version.version) >= LooseVersion('1.10')
try:
np_fft.fft(numpy.ones(4), norm=None)
return True
except TypeError:
return False

functions = {
'fft': 'complex',
Expand Down Expand Up @@ -114,6 +125,7 @@ class InterfacesNumpyFFTTestFFT(unittest.TestCase):
((59, 99), {'axis': -1}),
((59, 99), {'axis': 0}),
((32, 32, 4), {'axis': 1}),
((32, 32, 2), {'axis': 1, 'norm': 'ortho'}),
((64, 128, 16), {}),
)

Expand All @@ -126,13 +138,17 @@ class InterfacesNumpyFFTTestFFT(unittest.TestCase):
((100,), (100, -20), IndexError, ''))

realinv = False
has_norm_kwarg = _numpy_fft_has_norm_kwarg()

@property
def test_data(self):
for test_shape, kwargs in self.test_shapes:
axes = self.axes_from_kwargs(kwargs)
s = self.s_from_kwargs(test_shape, kwargs)

if not self.has_norm_kwarg and 'norm' in kwargs:
kwargs.pop('norm')

if self.realinv:
test_shape = list(test_shape)
test_shape[axes[-1]] = test_shape[axes[-1]]//2 + 1
Expand Down Expand Up @@ -183,15 +199,18 @@ def _validate(self, array_type, test_shape, dtype,
# a complex array is turned into a real array

if 'axes' in kwargs:
axes = {'axes': kwargs['axes']}
validator_kwargs = {'axes': kwargs['axes']}
elif 'axis' in kwargs:
axes = {'axis': kwargs['axis']}
validator_kwargs = {'axis': kwargs['axis']}
else:
axes = {}
validator_kwargs = {}

if self.has_norm_kwarg and 'norm' in kwargs:
validator_kwargs['norm'] = kwargs['norm']

try:
test_out_array = getattr(self.validator_module, self.func)(
copy_func(np_input_array), s, **axes)
copy_func(np_input_array), s, **validator_kwargs)

except Exception as e:
interface_exception = None
Expand Down Expand Up @@ -331,14 +350,18 @@ def test_on_non_numpy_array(self):
test_shape, dtype, s, kwargs)


def test_fail_on_invalid_s_or_axes(self):
def test_fail_on_invalid_s_or_axes_or_norm(self):
dtype_tuple = self.io_dtypes[functions[self.func]]

for dtype in dtype_tuple[0]:

for test_shape, args, exception, e_str in self.invalid_args:
input_array = dtype_tuple[1](test_shape, dtype)

if len(args) > 2 and not self.has_norm_kwarg:
# skip tests invovling norm argument if it isn't available
continue

self.assertRaisesRegex(exception, e_str,
getattr(self.test_interface, self.func),
*((input_array,) + args))
Expand Down Expand Up @@ -601,17 +624,21 @@ class InterfacesNumpyFFTTestIFFT(InterfacesNumpyFFTTestFFT):

class InterfacesNumpyFFTTestRFFT(InterfacesNumpyFFTTestFFT):
func = 'rfft'
has_norm_kwarg = False

class InterfacesNumpyFFTTestIRFFT(InterfacesNumpyFFTTestFFT):
func = 'irfft'
realinv = True
has_norm_kwarg = False

class InterfacesNumpyFFTTestHFFT(InterfacesNumpyFFTTestFFT):
func = 'hfft'
realinv = True
has_norm_kwarg = False

class InterfacesNumpyFFTTestIHFFT(InterfacesNumpyFFTTestFFT):
func = 'ihfft'
has_norm_kwarg = False

class InterfacesNumpyFFTTestFFT2(InterfacesNumpyFFTTestFFT):
axes_kw = 'axes'
Expand All @@ -621,6 +648,7 @@ class InterfacesNumpyFFTTestFFT2(InterfacesNumpyFFTTestFFT):
((128, 32), {'axes': None}),
((128, 32, 4), {'axes': (0, 2)}),
((59, 100), {'axes': (-2, -1)}),
((32, 32), {'axes': (-2, -1), 'norm': 'ortho'}),
((64, 128, 16), {'axes': (0, 2)}),
((4, 6, 8, 4), {'axes': (0, 3)}),
)
Expand All @@ -631,7 +659,9 @@ class InterfacesNumpyFFTTestFFT2(InterfacesNumpyFFTTestFFT):
((100,), ((100, 200), (-3, -2, -1)), ValueError, 'Shape error'),
((100, 200), (100, -1), TypeError, ''),
((100, 200), ((100, 200), (-3, -2)), IndexError, 'Invalid axes'),
((100, 200), ((100,), (-3,)), IndexError, 'Invalid axes'))
((100, 200), ((100,), (-3,)), IndexError, 'Invalid axes'),
# pass invalid normalisation string
((100, 200), ((100,), (-3,), 'invalid_norm'), ValueError, ''))

def test_shape_and_s_different_lengths(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
Expand All @@ -653,17 +683,20 @@ class InterfacesNumpyFFTTestIFFT2(InterfacesNumpyFFTTestFFT2):

class InterfacesNumpyFFTTestRFFT2(InterfacesNumpyFFTTestFFT2):
func = 'rfft2'
has_norm_kwarg = False

class InterfacesNumpyFFTTestIRFFT2(InterfacesNumpyFFTTestFFT2):
func = 'irfft2'
realinv = True
has_norm_kwarg = False

class InterfacesNumpyFFTTestFFTN(InterfacesNumpyFFTTestFFT2):
func = 'ifftn'
test_shapes = (
((128, 32, 4), {'axes': None}),
((64, 128, 16), {'axes': (0, 1, 2)}),
((4, 6, 8, 4), {'axes': (0, 3, 1)}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': 'ortho'}),
((4, 6, 8, 4), {'axes': (0, 3, 1, 2)}),
)

Expand All @@ -672,10 +705,12 @@ class InterfacesNumpyFFTTestIFFTN(InterfacesNumpyFFTTestFFTN):

class InterfacesNumpyFFTTestRFFTN(InterfacesNumpyFFTTestFFTN):
func = 'rfftn'
has_norm_kwarg = False

class InterfacesNumpyFFTTestIRFFTN(InterfacesNumpyFFTTestFFTN):
func = 'irfftn'
realinv = True
has_norm_kwarg = False

test_cases = (
InterfacesNumpyFFTTestModule,
Expand Down

0 comments on commit 264839b

Please sign in to comment.