Skip to content

ENH: Extend the nonlinear transforms API #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nitransforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""
from . import linear, manip, nonlinear
from .linear import Affine, LinearTransformsMapping
from .nonlinear import DisplacementsFieldTransform
from .nonlinear import DenseFieldTransform
from .manip import TransformChain

try:
Expand All @@ -42,7 +42,7 @@
"nonlinear",
"Affine",
"LinearTransformsMapping",
"DisplacementsFieldTransform",
"DenseFieldTransform",
"TransformChain",
"__copyright__",
"__packagename__",
Expand Down
4 changes: 2 additions & 2 deletions nitransforms/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
TransformError,
)
from .linear import Affine
from .nonlinear import DisplacementsFieldTransform
from .nonlinear import DenseFieldTransform


class TransformChain(TransformBase):
Expand Down Expand Up @@ -197,7 +197,7 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
if isinstance(xfmobj, itk.ITKLinearTransform):
retval.insert(0, Affine(xfmobj.to_ras(), reference=reference))
else:
retval.insert(0, DisplacementsFieldTransform(xfmobj))
retval.insert(0, DenseFieldTransform(xfmobj))

return TransformChain(retval)

Expand Down
137 changes: 110 additions & 27 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,51 @@
)


class DisplacementsFieldTransform(TransformBase):
"""Represents a dense field of displacements (one vector per voxel)."""
class DenseFieldTransform(TransformBase):
"""Represents dense field (voxel-wise) transforms."""

__slots__ = ["_field"]
__slots__ = ("_field", "_deltas")

def __init__(self, field, reference=None):
def __init__(self, field=None, is_deltas=True, reference=None):
"""
Create a dense deformation field transform.
Create a dense field transform.

Converting to a field of deformations is straightforward by just adding the corresponding
displacement to the :math:`(x, y, z)` coordinates of each voxel.
Numerically, deformation fields are less susceptible to rounding errors
than displacements fields.
SPM generally prefers deformations for that reason.

Parameters
----------
field : :obj:`numpy.array_like` or :obj:`nibabel.SpatialImage`
The field of deformations or displacements (*deltas*). If given as a data array,
then the reference **must** be given.
is_deltas : :obj:`bool`
Whether this is a displacements (deltas) field (default), or deformations.
reference : :obj:`ImageGrid`
Defines the domain of the transform. If not provided, the domain is defined from
the ``field`` input.

Example
-------
>>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
<DisplacementFieldTransform[3D] (57, 67, 56)>
>>> DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
<DenseFieldTransform[3D] (57, 67, 56)>

"""
if field is None and reference is None:
raise TransformError("DenseFieldTransforms require a spatial reference")

super().__init__()

field = _ensure_image(field)
self._field = np.squeeze(
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
)
if field is not None:
field = _ensure_image(field)
self._field = np.squeeze(
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
)
else:
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
is_deltas = True

try:
self.reference = ImageGrid(
Expand All @@ -59,45 +83,61 @@ def __init__(self, field, reference=None):
ndim = self._field.ndim - 1
if self._field.shape[-1] != ndim:
raise TransformError(
"The number of components of the displacements (%d) does not "
"The number of components of the field (%d) does not match "
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
)

if is_deltas:
self._deltas = self._field
# Convert from displacements (deltas) to deformations fields
# (just add its origin to each delta vector)
self._field += self.reference.ndcoords.T.reshape(self._field.shape)

def __repr__(self):
"""Beautify the python representation."""
return f"<DisplacementFieldTransform[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"

def map(self, x, inverse=False):
r"""
Apply the transformation to a list of physical coordinate points.

.. math::
\mathbf{y} = \mathbf{x} + D(\mathbf{x}),
\mathbf{y} = \mathbf{x} + \Delta(\mathbf{x}),
\label{eq:2}\tag{2}

where :math:`D(\mathbf{x})` is the value of the discrete field of displacements
:math:`D` interpolated at the location :math:`\mathbf{x}`.
where :math:`\Delta(\mathbf{x})` is the value of the discrete field of displacements
:math:`\Delta` interpolated at the location :math:`\mathbf{x}`.

Parameters
----------
x : N x D numpy.ndarray
x : N x D :obj:`numpy.array_like`
Input RAS+ coordinates (i.e., physical coordinates).
inverse : bool
inverse : :obj:`bool`
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.

Returns
-------
y : N x D numpy.ndarray
y : N x D :obj:`numpy.array_like`
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).

Examples
--------
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm = DenseFieldTransform(
... test_dir / "someones_displacement_field.nii.gz",
... is_deltas=False,
... )
>>> xfm.map([-6.5, -36., -19.5]).tolist()
[[-6.5, -36.475167989730835, -19.5]]
[[0.0, -0.47516798973083496, 0.0]]

>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
[[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
[[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]

>>> xfm = DenseFieldTransform(
... test_dir / "someones_displacement_field.nii.gz",
... is_deltas=True,
... )
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]

"""

Expand All @@ -106,9 +146,51 @@ def map(self, x, inverse=False):
ijk = self.reference.index(x)
indexes = np.round(ijk).astype("int")
if np.any(np.abs(ijk - indexes) > 0.05):
warnings.warn("Some coordinates are off-grid of the displacements field.")
warnings.warn("Some coordinates are off-grid of the field.")
indexes = tuple(tuple(i) for i in indexes.T)
return x + self._field[indexes]
return self._field[indexes]

def __matmul__(self, b):
"""
Compose with a transform on the right.

Examples
--------
>>> deff = DenseFieldTransform(
... test_dir / "someones_displacement_field.nii.gz",
... is_deltas=False,
... )
>>> deff2 = deff @ TransformBase()
>>> deff == deff2
True

>>> disp = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> disp2 = disp @ TransformBase()
>>> disp == disp2
True

"""
retval = b.map(
self._field.reshape((-1, self._field.shape[-1]))
).reshape(self._field.shape)
return DenseFieldTransform(retval, is_deltas=False, reference=self.reference)

def __eq__(self, other):
"""
Overload equals operator.

Examples
--------
>>> xfm1 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
>>> xfm1 == xfm2
True

"""
_eq = np.array_equal(self._field, other._field)
if _eq and self._reference != other._reference:
warnings.warn("Fields are equal, but references do not match.")
return _eq

@classmethod
def from_filename(cls, filename, fmt="X5"):
Expand All @@ -123,7 +205,7 @@ def from_filename(cls, filename, fmt="X5"):
return cls(_factory[fmt].from_filename(filename))


load = DisplacementsFieldTransform.from_filename
load = DenseFieldTransform.from_filename


class BSplineFieldTransform(TransformBase):
Expand Down Expand Up @@ -169,8 +251,9 @@ def to_field(self, reference=None, dtype="float32"):
# 1 x Nvox : (1 x K) @ (K x Nvox)
field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights

return DisplacementsFieldTransform(
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref)
return DenseFieldTransform(
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
)

def apply(
self,
Expand Down
20 changes: 14 additions & 6 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nitransforms.io.base import TransformFileError
from nitransforms.nonlinear import (
BSplineFieldTransform,
DisplacementsFieldTransform,
DenseFieldTransform,
load as nlload,
)
from ..io.itk import ITKDisplacementsField
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_itk_disp_load(size):
def test_displacements_bad_sizes(size):
"""Checks field sizes."""
with pytest.raises(TransformError):
DisplacementsFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
DenseFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))


def test_itk_disp_load_intent():
Expand All @@ -59,15 +59,23 @@ def test_itk_disp_load_intent():


def test_displacements_init():
DisplacementsFieldTransform(
identity1 = DenseFieldTransform(
np.zeros((10, 10, 10, 3)),
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
)
identity2 = DenseFieldTransform(
reference=nb.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4), None),
)

assert np.array_equal(identity1._field, identity2._field)
assert np.array_equal(identity1._deltas, identity2._deltas)

with pytest.raises(TransformError):
DisplacementsFieldTransform(np.zeros((10, 10, 10, 3)))
DenseFieldTransform()
with pytest.raises(TransformError):
DenseFieldTransform(np.zeros((10, 10, 10, 3)))
with pytest.raises(TransformError):
DisplacementsFieldTransform(
DenseFieldTransform(
np.zeros((10, 10, 10, 3)),
reference=np.zeros((10, 10, 10, 3)),
)
Expand Down Expand Up @@ -237,7 +245,7 @@ def test_bspline(tmp_path, testdata_path):
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"

bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
dispxfm = DisplacementsFieldTransform(disp_name)
dispxfm = DenseFieldTransform(disp_name)

out_disp = dispxfm.apply(img_name)
out_bspl = bsplxfm.apply(img_name)
Expand Down