Skip to content

Commit

Permalink
enh(displacementsfields): early implementation of itk importing and t…
Browse files Browse the repository at this point in the history
…ests

Closes #32
  • Loading branch information
oesteban committed Oct 30, 2019
1 parent 04ce632 commit 68ec96d
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 27 deletions.
36 changes: 36 additions & 0 deletions nitransforms/io/itk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Read/write ITK transforms."""
import warnings
import numpy as np
from scipy.io import savemat as _save_mat
from nibabel.loadsave import load as loadimg
from nibabel.affines import from_matvec
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError

Expand Down Expand Up @@ -245,3 +247,37 @@ def from_string(cls, string):
_self.xforms.append(cls._inner_type.from_string(
'#%s' % xfm))
return _self


class ITKDisplacementsField:
"""A data structure representing displacements fields."""

@classmethod
def from_filename(cls, filename):
"""Import a displacements field from a NIfTI file."""
imgobj = loadimg(str(filename))
return cls.from_image(imgobj)

@classmethod
def from_image(cls, imgobj):
"""Import a displacements field from a NIfTI file."""
_hdr = imgobj.header.copy()
_shape = _hdr.get_data_shape()

if (
len(_shape) != 5 or
_shape[-2] != 1 or
not _shape[-1] in (2, 3)
):
raise TransformFileError(
'Displacements field "%s" does not come from ITK.' %
imgobj.file_map['image'].filename)

if _hdr.get_intent()[0] != 'vector':
warnings.warn('Incorrect intent identified.')
_hdr.set_intent('vector')

_field = np.squeeze(np.asanyarray(imgobj.dataobj))
_field[..., (0, 1)] *= -1.0

return imgobj.__class__(_field, imgobj.affine, _hdr)
12 changes: 5 additions & 7 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,17 @@ class DisplacementsFieldTransform(TransformBase):

def __init__(self, field, reference=None):
"""Create a dense deformation field transform."""
super(DisplacementsFieldTransform, self).__init__()
super().__init__()
self._field = np.asanyarray(field.dataobj)

ndim = self._field.ndim - 1
if len(self._field.shape[:-1]) != ndim:
if self._field.shape[:-1] != ndim:
raise ValueError(
'Number of components of the deformation field does '
'The number of components of the displacements field does '
'not match the number of dimensions')

if reference is None:
reference = field.__class__(np.zeros(self._field.shape[:-1]),
field.affine, field.header)
self.reference = reference
self.reference = field.__class__(np.zeros(self._field.shape[:-1]),
field.affine, field.header)

def map(self, x, inverse=False, index=0):
r"""
Expand Down
96 changes: 76 additions & 20 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import numpy as np
import nibabel as nb
from ..io.base import TransformFileError
from ..nonlinear import DisplacementsFieldTransform
from ..io.itk import ITKDisplacementsField

TESTS_BORDER_TOLERANCE = 0.05
APPLY_NONLINEAR_CMD = {
Expand All @@ -19,34 +21,88 @@
}


@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 3)])
def test_itk_disp_load(size):
"""Checks field sizes."""
with pytest.raises(TransformFileError):
ITKDisplacementsField.from_image(
nb.Nifti1Image(np.zeros(size), None, None))


@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 1, 3)])
def test_displacements_bad_sizes(size):
"""Checks field sizes."""
with pytest.raises(ValueError):
DisplacementsFieldTransform(
nb.Nifti1Image(np.zeros(size), None, None))


def test_itk_disp_load_intent():
"""Checks whether the NIfTI intent is fixed."""
with pytest.warns(UserWarning):
field = ITKDisplacementsField.from_image(
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), None, None))

assert field.header.get_intent()[0] == 'vector'


@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@pytest.mark.parametrize('sw_tool', ['itk'])
def test_displacements_field(tmp_path, data_path, sw_tool):
@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
"""Check a translation-only field on one or more axes, different image orientations."""
os.chdir(str(tmp_path))
img_fname = str(data_path / 'tpl-OASIS30ANTs_T1w.nii.gz')
xfm_fname = str(
data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz')
ants_warp = nb.load(xfm_fname)
hdr = ants_warp.header.copy()
nii = get_testdata[image_orientation]
nii.to_filename('reference.nii.gz')
fieldmap = np.zeros((*nii.shape[:3], 1, 3), dtype='float32')
fieldmap[..., axis] = -10.0

_hdr = nii.header.copy()
_hdr.set_intent('vector')
_hdr.set_data_dtype('float32')

# fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
xfm_fname = 'warp.nii.gz'
nii = nb.load(img_fname)
fieldmap = np.zeros((*nii.shape[:3], 1, 3))
fieldmap[..., 2] = -10.0
# fieldmap = np.flip(np.flip(fieldmap, 1), 0)
ants_warp = nb.Nifti1Image(fieldmap, nii.affine, hdr)
ants_warp.to_filename(xfm_fname)
fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
field = nb.Nifti1Image(
fieldmap,
ants_warp.affine, ants_warp.header
)

xfm = DisplacementsFieldTransform(field)
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
field.to_filename(xfm_fname)

xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_image(field))

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=os.path.abspath(xfm_fname),
reference=tmp_path / 'reference.nii.gz',
moving=tmp_path / 'reference.nii.gz')

# skip test if command is not available on host
exe = cmd.split(" ", 1)[0]
if not shutil.which(exe):
pytest.skip("Command {} not found on host".format(exe))

exit_code = check_call([cmd], shell=True)
assert exit_code == 0
sw_moved = nb.load('resampled.nii.gz')

nt_moved = xfm.apply(nii, order=0)
nt_moved.to_filename('nt_resampled.nii.gz')
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE


@pytest.mark.parametrize('sw_tool', ['itk'])
def test_displacements_field2(tmp_path, data_path, sw_tool):
"""Check a translation-only field on one or more axes, different image orientations."""
os.chdir(str(tmp_path))
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz'

xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_filename(xfm_fname))

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=xfm_fname,
reference=img_fname,
moving=img_fname)

Expand Down

0 comments on commit 68ec96d

Please sign in to comment.