Skip to content

Commit

Permalink
TST: Test 2-D and N-D Dask wrapped FFT functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Mar 13, 2018
1 parent 7a77145 commit a2772cf
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions test/test_pyfftw_dask_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,17 @@ def _dask_array_fft_has_norm_kwarg():

functions = {
'fft': 'complex',
'fft2': 'complex',
'fftn': 'complex',
'ifft': 'complex',
'ifft2': 'complex',
'ifftn': 'complex',
'rfft': 'r2c',
'rfft2': 'r2c',
'rfftn': 'r2c',
'irfft': 'c2r',
'irfft2': 'c2r',
'irfftn': 'c2r',
'hfft': 'c2r',
'ihfft': 'r2c'}

Expand Down Expand Up @@ -472,19 +480,91 @@ def test_input_maintained(self):
numpy.alltrue(input_array == orig_input_array))


class InterfacesDaskFFTTestFFT2(InterfacesDaskFFTTestFFT):
axes_kw = 'axes'
func = 'ifft2'
has_norm_kwarg = False
test_shapes = (
((128, 64), {'axes': None}),
((128, 32), {'axes': None}),
((128, 32, 4), {'axes': (0, 2)}),
((59, 100), {'axes': (-2, -1)}),
((32, 32), {'axes': (-2, -1), 'norm': 'ortho'}),
((64, 128, 16), {'axes': (0, 2)}),
((4, 6, 8, 4), {'axes': (0, 3)}),
)

invalid_args = ()

def test_shape_and_s_different_lengths(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, _kwargs in self.test_data:
kwargs = copy.copy(_kwargs)
try:
s = s[1:]
except TypeError:
self.skipTest('Not meaningful test on 1d arrays.')

# Convert empty tuples to None
s = s if s else None

del kwargs['axes']
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)


class InterfacesDaskFFTTestFFTN(InterfacesDaskFFTTestFFT2):
func = 'ifftn'
has_norm_kwarg = False
test_shapes = (
((128, 32, 4), {'axes': None}),
((64, 128, 16), {'axes': (0, 1, 2)}),
((4, 6, 8, 4), {'axes': (0, 3, 1)}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': 'ortho'}),
((4, 6, 8, 4), {'axes': (0, 3, 1, 2)}),
)


class InterfacesDaskFFTTestIFFT(InterfacesDaskFFTTestFFT):
func = 'ifft'
has_norm_kwarg = False

class InterfacesDaskFFTTestIFFT2(InterfacesDaskFFTTestFFT2):
func = 'ifft2'
has_norm_kwarg = False

class InterfacesDaskFFTTestIFFTN(InterfacesDaskFFTTestFFTN):
func = 'ifftn'
has_norm_kwarg = False

class InterfacesDaskFFTTestRFFT(InterfacesDaskFFTTestFFT):
func = 'rfft'
has_norm_kwarg = False

class InterfacesDaskFFTTestRFFT2(InterfacesDaskFFTTestFFT2):
func = 'rfft2'
has_norm_kwarg = False

class InterfacesDaskFFTTestRFFTN(InterfacesDaskFFTTestFFTN):
func = 'rfftn'
has_norm_kwarg = False

class InterfacesDaskFFTTestIRFFT(InterfacesDaskFFTTestFFT):
func = 'irfft'
realinv = True
has_norm_kwarg = False

class InterfacesDaskFFTTestIRFFT2(InterfacesDaskFFTTestFFT2):
func = 'irfft2'
has_norm_kwarg = False
realinv = True

class InterfacesDaskFFTTestIRFFTN(InterfacesDaskFFTTestFFTN):
func = 'irfftn'
has_norm_kwarg = False
realinv = True

class InterfacesDaskFFTTestHFFT(InterfacesDaskFFTTestFFT):
func = 'hfft'
realinv = True
Expand All @@ -497,9 +577,17 @@ class InterfacesDaskFFTTestIHFFT(InterfacesDaskFFTTestFFT):
test_cases = (
InterfacesDaskFFTTestModule,
InterfacesDaskFFTTestFFT,
InterfacesDaskFFTTestFFT2,
InterfacesDaskFFTTestFFTN,
InterfacesDaskFFTTestIFFT,
InterfacesDaskFFTTestIFFT2,
InterfacesDaskFFTTestIFFTN,
InterfacesDaskFFTTestRFFT,
InterfacesDaskFFTTestRFFT2,
InterfacesDaskFFTTestRFFTN,
InterfacesDaskFFTTestIRFFT,
InterfacesDaskFFTTestIRFFT2,
InterfacesDaskFFTTestIRFFTN,
InterfacesDaskFFTTestHFFT,
InterfacesDaskFFTTestIHFFT)

Expand Down

0 comments on commit a2772cf

Please sign in to comment.