Skip to content

[ENH] New WarpPoints interface in algorithms.mesh #889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 12, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Next release
============

* ENH: New algorithm: mesh.WarpPoints applies displacements fields to point sets
(https://github.com/nipy/nipype/pull/889).
* ENH: New interfaces for MRTrix3 (https://github.com/nipy/nipype/pull/1126)
* ENH: New option in afni.3dRefit - zdel, ydel, zdel etc. (https://github.com/nipy/nipype/pull/1079)
* FIX: ants.Registration composite transform outputs are no longer returned as lists (https://github.com/nipy/nipype/pull/1183)
Expand Down
134 changes: 129 additions & 5 deletions nipype/algorithms/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,140 @@
iflogger = logging.getLogger('interface')


class WarpPointsInputSpec(BaseInterfaceInputSpec):
points = File(exists=True, mandatory=True,
desc=('file containing the point set'))
warp = File(exists=True, mandatory=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does this field come from? different nonlinear registration algorithms store this field in different ways. we should specify which types of fields this tool supports.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm getting satisfactory results using fields generated with FSL and convertwarp. Therefore, I guess any field written by FSL (fields, not coefficients) should work out. But true, we should specify this.

I'll be checking what other interfaces generate compatible fields (and/or implement filters for those that are somehow direct). Basically, any 4D image with 3 components should be enough.

Points that fall outside the FoV of the image will experiment zero-displacement.

desc=('dense deformation field to be applied'))
interp = traits.Enum('cubic', 'nearest', 'linear', usedefault=True,
mandatory=True, desc='interpolation')
out_points = File(name_source='points', name_template='%s_warped',
output_name='out_points', keep_extension=True,
desc='the warped point set')


class WarpPointsOutputSpec(TraitedSpec):
out_points = File(desc='the warped point set')


class WarpPoints(BaseInterface):

"""
Applies a displacement field to a point set given in vtk format.
Any discrete deformation field, given in physical coordinates and
which volume covers the extent of the vtk point set, is a valid
``warp`` file. FSL interfaces are compatible, for instance any
field computed with :class:`nipype.interfaces.fsl.utils.ConvertWarp`.

Example
-------

>>> from nipype.algorithms.mesh import WarpPoints
>>> wp = WarpPoints()
>>> wp.inputs.points = 'surf1.vtk'
>>> wp.inputs.warp = 'warpfield.nii'
>>> res = wp.run() # doctest: +SKIP
"""
input_spec = WarpPointsInputSpec
output_spec = WarpPointsOutputSpec
_redirect_x = True

def _gen_fname(self, in_file, suffix='generated', ext=None):
import os.path as op

fname, fext = op.splitext(op.basename(in_file))

if fext == '.gz':
fname, fext2 = op.splitext(fname)
fext = fext2+fext

if ext is None:
ext = fext

if ext[0] == '.':
ext = ext[1:]
return op.abspath('%s_%s.%s' % (fname, suffix, ext))

def _run_interface(self, runtime):
vtk_major = 6
try:
import vtk
vtk_major = vtk.VTK_MAJOR_VERSION
except ImportError:
iflogger.warn(('python-vtk could not be imported'))

try:
from tvtk.api import tvtk
except ImportError:
raise ImportError('Interface requires tvtk')

try:
from enthought.etsconfig.api import ETSConfig
ETSConfig.toolkit = 'null'
except ImportError:
iflogger.warn(('ETS toolkit could not be imported'))
except ValueError:
iflogger.warn(('ETS toolkit could not be set to null'))

import nibabel as nb
import numpy as np
from scipy import ndimage

r = tvtk.PolyDataReader(file_name=self.inputs.points)
r.update()
mesh = r.output
points = np.array(mesh.points)
warp_dims = nb.funcs.four_to_three(nb.load(self.inputs.warp))

affine = warp_dims[0].get_affine()
voxsize = warp_dims[0].get_header().get_zooms()
vox2ras = affine[0:3, 0:3]
ras2vox = np.linalg.inv(vox2ras)
origin = affine[0:3, 3]
voxpoints = np.array([np.dot(ras2vox,
(p-origin)) for p in points])

warps = []
for axis in warp_dims:
wdata = axis.get_data()
if np.any(wdata != 0):

warp = ndimage.map_coordinates(wdata,
voxpoints.transpose())
else:
warp = np.zeros((points.shape[0],))

warps.append(warp)

disps = np.squeeze(np.dstack(warps))
newpoints = [p+d for p, d in zip(points, disps)]
mesh.points = newpoints
w = tvtk.PolyDataWriter()
if vtk_major <= 5:
w.input = mesh
else:
w.set_input_data_object(mesh)

w.file_name = self._gen_fname(self.inputs.points,
suffix='warped',
ext='.vtk')
w.write()
return runtime

def _list_outputs(self):
outputs = self._outputs().get()
outputs['out_points'] = self._gen_fname(self.inputs.points,
suffix='warped',
ext='.vtk')
return outputs


class ComputeMeshWarpInputSpec(BaseInterfaceInputSpec):
surface1 = File(exists=True, mandatory=True,
desc=('Reference surface (vtk format) to which compute '
'distance.'))
surface2 = File(exists=True, mandatory=True,

desc=('Test surface (vtk format) from which compute '
'distance.'))
metric = traits.Enum('euclidean', 'sqeuclidean', usedefault=True,
Expand Down Expand Up @@ -101,10 +230,8 @@ def _run_interface(self, runtime):
ETSConfig.toolkit = 'null'
except ImportError:
iflogger.warn(('ETS toolkit could not be imported'))
pass
except ValueError:
iflogger.warn(('ETS toolkit is already set'))
pass

r1 = tvtk.PolyDataReader(file_name=self.inputs.surface1)
r2 = tvtk.PolyDataReader(file_name=self.inputs.surface2)
Expand All @@ -124,7 +251,6 @@ def _run_interface(self, runtime):
errvector = nla.norm(diff, axis=1)
except TypeError: # numpy < 1.9
errvector = np.apply_along_axis(nla.norm, 1, diff)
pass

if self.inputs.metric == 'sqeuclidean':
errvector = errvector ** 2
Expand Down Expand Up @@ -235,10 +361,8 @@ def _run_interface(self, runtime):
ETSConfig.toolkit = 'null'
except ImportError:
iflogger.warn(('ETS toolkit could not be imported'))
pass
except ValueError:
iflogger.warn(('ETS toolkit is already set'))
pass

r1 = tvtk.PolyDataReader(file_name=self.inputs.in_surf)
vtk1 = r1.output
Expand Down