Skip to content

ENH: Add LibraryBaseInterface #2538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 15, 2018
Merged
15 changes: 2 additions & 13 deletions nipype/algorithms/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from scipy.ndimage.measurements import center_of_mass, label

from .. import config, logging
from ..utils.misc import package_check

from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
InputMultiPath, BaseInterfaceInputSpec,
isdefined)
from ..interfaces.nipy.base import NipyBaseInterface
from ..utils import NUMPY_MMAP

iflogger = logging.getLogger('interface')
Expand Down Expand Up @@ -651,7 +651,7 @@ class SimilarityOutputSpec(TraitedSpec):
traits.Float(desc="Similarity between volume 1 and 2, frame by frame"))


class Similarity(BaseInterface):
class Similarity(NipyBaseInterface):
"""Calculates similarity between two 3D or 4D volumes. Both volumes have to be in
the same coordinate system, same space within that coordinate system and
with the same voxel dimensions.
Expand All @@ -674,19 +674,8 @@ class Similarity(BaseInterface):

input_spec = SimilarityInputSpec
output_spec = SimilarityOutputSpec
_have_nipy = True

def __init__(self, **inputs):
try:
package_check('nipy')
except Exception:
self._have_nipy = False
super(Similarity, self).__init__(**inputs)

def _run_interface(self, runtime):
if not self._have_nipy:
raise RuntimeError('nipy is not installed')

from nipy.algorithms.registration.histogram_registration import HistogramRegistration
from nipy.algorithms.registration.affine import Affine

Expand Down
2 changes: 1 addition & 1 deletion nipype/interfaces/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
from .core import (Interface, BaseInterface, SimpleInterface, CommandLine,
StdOutCommandLine, MpiCommandLine, SEMLikeCommandLine,
PackageInfo)
LibraryBaseInterface, PackageInfo)

from .specs import (BaseTraitedSpec, TraitedSpec, DynamicTraitedSpec,
BaseInterfaceInputSpec, CommandLineInputSpec,
Expand Down
29 changes: 29 additions & 0 deletions nipype/interfaces/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,35 @@ def _format_arg(self, name, spec, value):
return super(SEMLikeCommandLine, self)._format_arg(name, spec, value)


class LibraryBaseInterface(BaseInterface):
_pkg = None
imports = ()

def __init__(self, check_import=True, *args, **kwargs):
super(LibraryBaseInterface, self).__init__(*args, **kwargs)
if check_import:
import importlib
failed_imports = []
for pkg in (self._pkg,) + tuple(self.imports):
try:
importlib.import_module(pkg)
except ImportError:
failed_imports.append(pkg)
if failed_imports:
iflogger.warn('Unable to import %s; %s interface may fail to '
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this going to generate a lot of warnings when we create docs? or make specs? i.e. things that simply create an instance of an interface:

InterfaceX()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, potentially. I guess we could set a class-level flag to only display once per class? Or do you have another preferred approach?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one option would be to remove this if. let's merge it and see how annoying it gets.

'run', failed_imports, self.__class__.__name__)

@property
def version(self):
if self._version is None:
import importlib
try:
self._version = importlib.import_module(self._pkg).__version__
except (ImportError, AttributeError):
pass
return super(LibraryBaseInterface, self).version


class PackageInfo(object):
_version = None
version_cmd = None
Expand Down
17 changes: 17 additions & 0 deletions nipype/interfaces/base/tests/test_auto_LibraryBaseInterface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
from __future__ import unicode_literals
from ..core import LibraryBaseInterface


def test_LibraryBaseInterface_inputs():
input_map = dict(
ignore_exception=dict(
deprecated='1.0.0',
nohash=True,
usedefault=True,
), )
inputs = LibraryBaseInterface.input_spec()

for key, metadata in list(input_map.items()):
for metakey, value in list(metadata.items()):
assert getattr(inputs.traits()[key], metakey) == value
33 changes: 33 additions & 0 deletions nipype/interfaces/cmtk/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
""" Base interface for cmtk """

from ..base import LibraryBaseInterface
from ...utils.misc import package_check


class CFFBaseInterface(LibraryBaseInterface):
_pkg = 'cfflib'


# Originally set in convert, nbs, nx, parcellation
# Set here to be imported, in case anybody depends on its presence
# Remove in 2.0
have_cmp = True
try:
package_check('cmp')
except ImportError:
have_cmp = False

have_cfflib = True
try:
package_check('cfflib')
except ImportError:
have_cfflib = False

have_cv = True
try:
package_check('cviewer')
except ImportError:
have_cv = False
18 changes: 6 additions & 12 deletions nipype/interfaces/cmtk/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,10 @@
import string
import networkx as nx

from ...utils.misc import package_check
from ...utils.filemanip import split_filename
from ..base import (BaseInterface, BaseInterfaceInputSpec, traits, File,
from ..base import (BaseInterfaceInputSpec, traits, File,
TraitedSpec, InputMultiPath, isdefined)

have_cfflib = True
try:
package_check('cfflib')
except Exception as e:
have_cfflib = False
else:
import cfflib as cf
from .base import CFFBaseInterface, have_cfflib


class CFFConverterInputSpec(BaseInterfaceInputSpec):
Expand Down Expand Up @@ -67,7 +59,7 @@ class CFFConverterOutputSpec(TraitedSpec):
connectome_file = File(exists=True, desc='Output connectome file')


class CFFConverter(BaseInterface):
class CFFConverter(CFFBaseInterface):
"""
Creates a Connectome File Format (CFF) file from input networks, surfaces, volumes, tracts, etcetera....

Expand All @@ -87,6 +79,7 @@ class CFFConverter(BaseInterface):
output_spec = CFFConverterOutputSpec

def _run_interface(self, runtime):
import cfflib as cf
a = cf.connectome()

if isdefined(self.inputs.title):
Expand Down Expand Up @@ -232,7 +225,7 @@ class MergeCNetworksOutputSpec(TraitedSpec):
exists=True, desc='Output CFF file with all the networks added')


class MergeCNetworks(BaseInterface):
class MergeCNetworks(CFFBaseInterface):
""" Merges networks from multiple CFF files into one new CFF file.

Example
Expand All @@ -248,6 +241,7 @@ class MergeCNetworks(BaseInterface):
output_spec = MergeCNetworksOutputSpec

def _run_interface(self, runtime):
import cfflib as cf
extracted_networks = []

for i, con in enumerate(self.inputs.in_files):
Expand Down
19 changes: 5 additions & 14 deletions nipype/interfaces/cmtk/nbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,11 @@
import networkx as nx

from ... import logging
from ...utils.misc import package_check
from ..base import (BaseInterface, BaseInterfaceInputSpec, traits, File,
from ..base import (LibraryBaseInterface, BaseInterfaceInputSpec, traits, File,
TraitedSpec, InputMultiPath, OutputMultiPath, isdefined)
from .base import have_cv
iflogger = logging.getLogger('interface')

have_cv = True
try:
package_check('cviewer')
except Exception as e:
have_cv = False
else:
import cviewer.libs.pyconto.groupstatistics.nbs as nbs


def ntwks_to_matrices(in_files, edge_key):
first = nx.read_gpickle(in_files[0])
Expand Down Expand Up @@ -92,7 +84,7 @@ class NetworkBasedStatisticOutputSpec(TraitedSpec):
desc='Output network with edges identified by the NBS')


class NetworkBasedStatistic(BaseInterface):
class NetworkBasedStatistic(LibraryBaseInterface):
"""
Calculates and outputs the average network given a set of input NetworkX gpickle files

Expand All @@ -111,11 +103,10 @@ class NetworkBasedStatistic(BaseInterface):
"""
input_spec = NetworkBasedStatisticInputSpec
output_spec = NetworkBasedStatisticOutputSpec
_pkg = 'cviewer'

def _run_interface(self, runtime):

if not have_cv:
raise ImportError("cviewer library is not available")
from cviewer.libs.pyconto.groupstatistics import nbs

THRESH = self.inputs.threshold
K = self.inputs.number_of_permutations
Expand Down
10 changes: 1 addition & 9 deletions nipype/interfaces/cmtk/nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,12 @@

from ... import logging
from ...utils.filemanip import split_filename
from ...utils.misc import package_check
from ..base import (BaseInterface, BaseInterfaceInputSpec, traits, File,
TraitedSpec, InputMultiPath, OutputMultiPath, isdefined)
from .base import have_cmp

iflogger = logging.getLogger('interface')

have_cmp = True
try:
package_check('cmp')
except Exception as e:
have_cmp = False
else:
import cmp


def read_unknown_ntwk(ntwk):
if not isinstance(ntwk, nx.classes.graph.Graph):
Expand Down
31 changes: 13 additions & 18 deletions nipype/interfaces/cmtk/parcellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,22 @@
import os
import os.path as op
import shutil
import warnings

import numpy as np
import nibabel as nb
import networkx as nx

from ... import logging
from ...utils.misc import package_check
from ..base import (BaseInterface, BaseInterfaceInputSpec, traits, File,
from ..base import (BaseInterface, LibraryBaseInterface,
BaseInterfaceInputSpec, traits, File,
TraitedSpec, Directory, isdefined)
from .base import have_cmp
iflogger = logging.getLogger('interface')

have_cmp = True
try:
package_check('cmp')
except Exception as e:
have_cmp = False
else:
import cmp
from cmp.util import runCmd


def create_annot_label(subject_id, subjects_dir, fs_dir, parcellation_name):
import cmp
from cmp.util import runCmd
iflogger.info("Create the cortical labels necessary for our ROIs")
iflogger.info("=================================================")
fs_label_dir = op.join(op.join(subjects_dir, subject_id), 'label')
Expand Down Expand Up @@ -174,6 +167,8 @@ def create_annot_label(subject_id, subjects_dir, fs_dir, parcellation_name):
def create_roi(subject_id, subjects_dir, fs_dir, parcellation_name, dilation):
""" Creates the ROI_%s.nii.gz files using the given parcellation information
from networks. Iteratively create volume. """
import cmp
from cmp.util import runCmd
iflogger.info("Create the ROIs:")
output_dir = op.abspath(op.curdir)
fs_dir = op.join(subjects_dir, subject_id)
Expand Down Expand Up @@ -306,6 +301,8 @@ def create_roi(subject_id, subjects_dir, fs_dir, parcellation_name, dilation):


def create_wm_mask(subject_id, subjects_dir, fs_dir, parcellation_name):
import cmp
import scipy.ndimage.morphology as nd
iflogger.info("Create white matter mask")
fs_dir = op.join(subjects_dir, subject_id)
cmp_config = cmp.configuration.PipelineConfiguration()
Expand All @@ -328,11 +325,6 @@ def create_wm_mask(subject_id, subjects_dir, fs_dir, parcellation_name):
aseg = nb.load(op.join(fs_dir, 'mri', 'aseg.nii.gz'))
asegd = aseg.get_data()

try:
import scipy.ndimage.morphology as nd
except ImportError:
raise Exception('Need scipy for binary erosion of white matter mask')

# need binary erosion function
imerode = nd.binary_erosion

Expand Down Expand Up @@ -438,6 +430,7 @@ def create_wm_mask(subject_id, subjects_dir, fs_dir, parcellation_name):

def crop_and_move_datasets(subject_id, subjects_dir, fs_dir, parcellation_name,
out_roi_file, dilation):
from cmp.util import runCmd
fs_dir = op.join(subjects_dir, subject_id)
cmp_config = cmp.configuration.PipelineConfiguration()
cmp_config.parcellation_scheme = "Lausanne2008"
Expand Down Expand Up @@ -549,7 +542,7 @@ class ParcellateOutputSpec(TraitedSpec):
)


class Parcellate(BaseInterface):
class Parcellate(LibraryBaseInterface):
"""Subdivides segmented ROI file into smaller subregions

This interface implements the same procedure as in the ConnectomeMapper's
Expand All @@ -571,6 +564,8 @@ class Parcellate(BaseInterface):

input_spec = ParcellateInputSpec
output_spec = ParcellateOutputSpec
_pkg = 'cmp'
imports = ('scipy', )

def _run_interface(self, runtime):
if self.inputs.subjects_dir:
Expand Down
17 changes: 17 additions & 0 deletions nipype/interfaces/cmtk/tests/test_auto_CFFBaseInterface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
from __future__ import unicode_literals
from ..base import CFFBaseInterface


def test_CFFBaseInterface_inputs():
input_map = dict(
ignore_exception=dict(
deprecated='1.0.0',
nohash=True,
usedefault=True,
), )
inputs = CFFBaseInterface.input_spec()

for key, metadata in list(input_map.items()):
for metakey, value in list(metadata.items()):
assert getattr(inputs.traits()[key], metakey) == value
1 change: 0 additions & 1 deletion nipype/interfaces/cmtk/tests/test_nbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def test_importerror(creating_graphs, tmpdir):

with pytest.raises(ImportError) as e:
nbs.run()
assert "cviewer library is not available" == str(e.value)


@pytest.mark.skipif(not have_cv, reason="cviewer has to be available")
Expand Down
Loading