Skip to content

Commit

Permalink
Detect scipy.fft by version rather than existence
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Jul 19, 2019
1 parent 8443c29 commit 9697c15
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
16 changes: 7 additions & 9 deletions pyfftw/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,15 @@
except ImportError:
pass
else:
del scipy
from . import scipy_fftpack

from distutils.version import LooseVersion as _LooseVersion

try:
import scipy.fft
except ImportError:
pass
else:
has_scipy_fft = _LooseVersion(scipy.__version__) >= _LooseVersion('1.4.0')
del _LooseVersion
del scipy
from . import scipy_fft

from . import scipy_fftpack
if has_scipy_fft:
from . import scipy_fft


fft_wrap = None
Expand Down
39 changes: 23 additions & 16 deletions test/test_pyfftw_scipy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@
import numpy

try:
import scipy
scipy_version = scipy.__version__
except ImportError:
scipy_version = '0.0.0'

from distutils.version import LooseVersion
has_scipy_fft = LooseVersion(scipy_version) >= LooseVersion('1.4.0')

if has_scipy_fft:
import scipy.fft
from pyfftw.interfaces import scipy_fft
scipy_missing = False
except ImportError:
scipy_missing = True

import unittest
from .test_pyfftw_base import run_test_suites, miss
Expand Down Expand Up @@ -79,7 +85,7 @@ def make_c2r_real_data(shape, dtype):
'r2c': (real_dtypes, make_r2c_real_data),
'c2r': (real_dtypes, make_c2r_real_data)}

@unittest.skipIf(scipy_missing, 'scipy.fft is unavailable')
@unittest.skipIf(not has_scipy_fft, 'scipy.fft is unavailable')
class InterfacesScipyFFTTestSimple(unittest.TestCase):
''' A simple test suite for a simple implementation.
'''
Expand Down Expand Up @@ -112,22 +118,23 @@ def test_acquired_names(self):

# Construct all the test classes automatically.
test_cases = []
if not scipy_missing:
for each_func in funcs:
class_name = 'InterfacesScipyFFTTest' + each_func.upper()
for each_func in funcs:
class_name = 'InterfacesScipyFFTTest' + each_func.upper()

parent_class_name = 'InterfacesNumpyFFTTest' + each_func.upper()
parent_class = getattr(test_pyfftw_numpy_interface, parent_class_name)
parent_class_name = 'InterfacesNumpyFFTTest' + each_func.upper()
parent_class = getattr(test_pyfftw_numpy_interface, parent_class_name)

class_dict = {'validator_module': scipy.fft,
'test_interface': scipy_fft,
'io_dtypes': io_dtypes,
'overwrite_input_flag': 'overwrite_x',
'default_s_from_shape_slicer': slice(None)}
class_dict = {'validator_module': scipy.fft if has_scipy_fft else None,
'test_interface': scipy_fft if has_scipy_fft else None,
'io_dtypes': io_dtypes,
'overwrite_input_flag': 'overwrite_x',
'default_s_from_shape_slicer': slice(None)}

globals()[class_name] = type(class_name, (parent_class,), class_dict)
cls = type(class_name, (parent_class,), class_dict)
cls = unittest.skipIf(not has_scipy_fft, "scipy.fft is not available")(cls)

test_cases.append(globals()[class_name])
globals()[class_name] = cls
test_cases.append(cls)

test_cases.append(InterfacesScipyFFTTestSimple)
test_set = None
Expand Down

0 comments on commit 9697c15

Please sign in to comment.