enh: port from process pool into asyncio concurrent
Co-authored-by: Chris Markiewicz <effigies@gmail.com>
oesteban and effigies committed Aug 6, 2024
1 parent 7c7608f · commit 063e1f0
Showing 2 changed files with 112 additions and 75 deletions.
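For orientation, the snippet below is a minimal, self-contained sketch of the concurrency pattern this commit adopts: an asyncio.Semaphore bounds how many jobs run at once, and each blocking job is handed off to the default executor via loop.run_in_executor. The worker coroutine mirrors the helper added in resampling.py below; run_all, jobs, and max_concurrent are illustrative names for this sketch, not part of nitransforms.

import asyncio
from functools import partial
from typing import Callable, TypeVar

R = TypeVar("R")


async def worker(job: Callable[[], R], semaphore: asyncio.Semaphore) -> R:
    # The semaphore caps how many jobs may be in flight at any one time
    async with semaphore:
        loop = asyncio.get_running_loop()
        # Run the blocking callable on the default (thread-pool) executor
        return await loop.run_in_executor(None, job)


async def run_all(jobs, max_concurrent: int = 4):
    semaphore = asyncio.Semaphore(max_concurrent)
    tasks = [asyncio.create_task(worker(job, semaphore)) for job in jobs]
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    # Squaring a few numbers stands in for the blocking map_coordinates calls
    print(asyncio.run(run_all([partial(pow, i, 2) for i in range(8)], max_concurrent=3)))

Unlike the previous ProcessPoolExecutor approach, these jobs run in threads, so the input and output arrays are shared in memory rather than pickled across processes.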
172 changes: 111 additions & 61 deletions nitransforms/resampling.py
@@ -8,10 +8,11 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Resampling utilities."""

import asyncio
from os import cpu_count
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from pathlib import Path
from typing import Tuple
from typing import Callable, TypeVar

import numpy as np
from nibabel.loadsave import load as _nbload
@@ -27,30 +28,58 @@
_as_homogeneous,
)

R = TypeVar("R")

SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8
"""Minimum number of volumes to automatically serialize 4D transforms."""


def _apply_volume(
index: int,
async def worker(job: Callable[[], R], semaphore) -> R:
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, job)


async def _apply_serial(
data: np.ndarray,
spatialimage: SpatialImage,
targets: np.ndarray,
transform: TransformBase,
ref_ndim: int,
ref_ndcoords: np.ndarray,
n_resamplings: int,
output: np.ndarray,
input_dtype: np.dtype,
order: int = 3,
mode: str = "constant",
cval: float = 0.0,
prefilter: bool = True,
) -> Tuple[int, np.ndarray]:
max_concurrent: int = min(cpu_count(), 12),
):
"""
Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization.
Resample through a given transform serially, in a 3D+t setting.
Parameters
----------
index : :obj:`int`
The index of the volume to apply the interpolation to.
data : :obj:`~numpy.ndarray`
The input data array.
spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
The image object containing the data to be resampled in reference
space
targets : :obj:`~numpy.ndarray`
The target coordinates for mapping.
transform : :obj:`~nitransforms.base.TransformBase`
The 3D, 3D+t, or 4D transform through which data will be resampled.
ref_ndim : :obj:`int`
Dimensionality of the resampling target (reference image).
ref_ndcoords : :obj:`~numpy.ndarray`
Physical coordinates (RAS+) where data will be interpolated, if the resampling
target is a grid, the scanner coordinates of all voxels.
n_resamplings : :obj:`int`
Total number of 3D resamplings (can be defined by the input image, the transform,
or be matched, that is, same number of volumes in the input and number of transforms).
output : :obj:`~numpy.ndarray`
The output data array where resampled values will be stored volume-by-volume.
order : :obj:`int`, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
@@ -71,18 +100,46 @@ def _apply_volume(
Returns
-------
(:obj:`int`, :obj:`~numpy.ndarray`)
The index and the array resulting from the interpolation.
np.ndarray
Data resampled on the 3D+t array of input coordinates.
"""
return index, ndi.map_coordinates(
data,
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
tasks = []
semaphore = asyncio.Semaphore(max_concurrent)

for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim)
)

data_t = (
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
)

tasks.append(
asyncio.create_task(
worker(
partial(
ndi.map_coordinates,
data_t,
targets,
output=output[..., t],
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
),
semaphore,
)
)
)
await asyncio.gather(*tasks)
return output


def apply(
@@ -94,15 +151,17 @@ def apply(
cval: float = 0.0,
prefilter: bool = True,
output_dtype: np.dtype = None,
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
njobs: int = None,
dtype_width: int = 8,
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
max_concurrent: int = min(cpu_count(), 12),
) -> SpatialImage | np.ndarray:
"""
Apply a transformation to an image, resampling on the reference spatial object.
Parameters
----------
transform: :obj:`~nitransforms.base.TransformBase`
The 3D, 3D+t, or 4D transform through which data will be resampled.
spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
The image object containing the data to be resampled in reference
space
Expand All @@ -118,15 +177,15 @@ def apply(
or ``'wrap'``. Default is ``'constant'``.
cval : :obj:`float`, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: :obj:`bool`, optional
prefilter : :obj:`bool`, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
output_dtype: :obj:`~numpy.dtype`, optional
output_dtype : :obj:`~numpy.dtype`, optional
The dtype of the returned array or image, if specified.
If ``None``, the default behavior is to use the effective dtype of
the input image. If slope and/or intercept are defined, the effective
Expand All @@ -135,10 +194,17 @@ def apply(
If ``reference`` is defined, then the return value is an image, with
a data array of the effective dtype but with the on-disk dtype set to
the input image's on-disk dtype.
dtype_width: :obj:`int`
dtype_width : :obj:`int`
Cap the width of the input data type to the given number of bytes.
This argument is intended to work as a way to implement lower memory
requirements in resampling.
serialize_nvols : :obj:`int`
Minimum number of volumes in a 3D+t (that is, a series of 3D transformations
independent in time) to resample on a one-by-one basis.
Serialized resampling can be executed concurrently (parallelized) with
the argument ``max_concurrent``.
max_concurrent : :obj:`int`
Maximum number of 3D resamplings to be executed concurrently.
Returns
-------
@@ -201,46 +267,30 @@ def apply(
else None
)

njobs = cpu_count() if njobs is None or njobs < 1 else njobs

with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
results = []
for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)

data_t = (
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
)

results.append(
executor.submit(
_apply_volume,
t,
data_t,
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
)
# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
resampled = asyncio.run(
_apply_serial(
data,
spatialimage,
targets,
transform,
_ref.ndim,
ref_ndcoords,
n_resamplings,
resampled,
input_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
max_concurrent=max_concurrent,
)

for future in as_completed(results):
t, resampled_t = future.result()
resampled[..., t] = resampled_t
)
else:
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)

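To make the new options concrete, here is a usage sketch of the updated apply() signature documented above. The file names are placeholders, and building the per-volume transforms with LinearTransformsMapping follows the pattern used in the test suite; it is an assumed entry point, not something introduced by this commit.

import numpy as np
from nitransforms.linear import LinearTransformsMapping
from nitransforms.resampling import apply

# Placeholder inputs: a 4D series with one (identity) affine per volume
n_vols = 20
hmc = LinearTransformsMapping(np.stack([np.eye(4)] * n_vols))

# With serialize_nvols or more volumes, resampling proceeds volume-by-volume,
# keeping at most max_concurrent map_coordinates calls in flight at a time.
resampled = apply(
    hmc,
    "bold.nii.gz",               # placeholder path to the moving 4D image
    reference="boldref.nii.gz",  # placeholder path to the reference grid
    serialize_nvols=8,
    max_concurrent=4,
)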
15 changes: 1 addition & 14 deletions nitransforms/tests/test_resampling.py
@@ -15,7 +15,7 @@
from nitransforms import nonlinear as nitnl
from nitransforms import manip as nitm
from nitransforms import io
from nitransforms.resampling import apply, _apply_volume
from nitransforms.resampling import apply

RMSE_TOL_LINEAR = 0.09
RMSE_TOL_NONLINEAR = 0.05
@@ -363,16 +363,3 @@ def test_LinearTransformsMapping_apply(
reference=testdata_path / "sbref.nii.gz",
serialize_nvols=2 if serialize_4d else np.inf,
)


@pytest.mark.parametrize("t", list(range(4)))
def test_apply_helper(monkeypatch, t):
"""Ensure the apply helper function correctly just decorates with index."""
from nitransforms.resampling import ndi

def _retval(*args, **kwargs):
return 1

monkeypatch.setattr(ndi, "map_coordinates", _retval)

assert _apply_volume(t, None, None) == (t, 1)
