diff --git a/clinica/utils/testing_utils.py b/clinica/utils/testing_utils.py index 080f6ecd76..a3f935e677 100644 --- a/clinica/utils/testing_utils.py +++ b/clinica/utils/testing_utils.py @@ -1,6 +1,11 @@ import json import os +from functools import partial from pathlib import Path +from typing import Callable, Optional + +import nibabel as nib +from numpy.testing import assert_array_almost_equal, assert_array_equal def build_bids_directory(directory: os.PathLike, subjects_sessions: dict) -> None: @@ -199,3 +204,103 @@ def rmtree(f: Path): for child in f.iterdir(): rmtree(child) f.rmdir() + + +def _assert_nifti_relation( + img1: Path, + img2: Path, + assertion_func_data: Callable, + assertion_func_affine: Optional[Callable] = None, +) -> None: + """Assert that two nifti images satisfy some relationship. + + Parameters + ---------- + img1 : Path + Path to the first image. + + img2 : Path + Path to the second image. + + assertion_func_data : Callable + Assertion function for dataobj comparison. + The function should take two numpy arrays as input. + + assertion_func_affine : Callable, optional + Assertion function for affine comparison. + The function should take two numpy arrays as input. + If not specified, `assertion_func_data` will be used to + check affine matrices. + """ + assertion_func_affine = assertion_func_affine or assertion_func_data + img1 = nib.load(img1) + img2 = nib.load(img2) + + assert img1.shape == img2.shape # Fail fast + assertion_func_affine(img1, img2) + assertion_func_data(img1, img2) + + +def _assert_affine_equal(img1: nib.Nifti1Image, img2: nib.Nifti1Image) -> None: + assert_array_equal(img1.affine, img2.affine) + + +def _assert_dataobj_equal(img1: nib.Nifti1Image, img2: nib.Nifti1Image) -> None: + assert_array_equal(img1.get_fdata(), img2.get_fdata()) + + +def _assert_affine_almost_equal( + img1: nib.Nifti1Image, img2: nib.Nifti1Image, decimal: int = 6 +) -> None: + assert_array_almost_equal(img1.affine, img2.affine, decimal=decimal) + + +def _assert_dataobj_almost_equal( + img1: nib.Nifti1Image, img2: nib.Nifti1Image, decimal: int = 6 +) -> None: + assert_array_almost_equal(img1.get_fdata(), img2.get_fdata(), decimal=decimal) + + +def _assert_large_image_dataobj_almost_equal( + img1: nib.Nifti1Image, + img2: nib.Nifti1Image, + decimal: int = 6, + n_samples: Optional[int] = None, + verbose: bool = False, +) -> None: + import random + + import numpy as np + + volumes = range(0, img1.shape[-1]) + if n_samples: + volumes = random.sample(volumes, n_samples) + for volume in volumes: + if verbose: + print(f"--> Processing volume {volume}...") + assert_array_almost_equal( + np.asarray(img1.dataobj[..., volume]), + np.asarray(img2.dataobj[..., volume]), + decimal=decimal, + ) + + +assert_nifti_equal = partial( + _assert_nifti_relation, + assertion_func_data=_assert_dataobj_equal, + assertion_func_affine=_assert_affine_equal, +) + + +assert_nifti_almost_equal = partial( + _assert_nifti_relation, + assertion_func_data=_assert_dataobj_almost_equal, + assertion_func_affine=_assert_affine_almost_equal, +) + + +assert_large_nifti_almost_equal = partial( + _assert_nifti_relation, + assertion_func_data=_assert_large_image_dataobj_almost_equal, + assertion_func_affine=_assert_affine_almost_equal, +) diff --git a/test/nonregression/pipelines/test_run_pipelines_dwi.py b/test/nonregression/pipelines/test_run_pipelines_dwi.py index 49ca423fd0..897400e793 100644 --- a/test/nonregression/pipelines/test_run_pipelines_dwi.py +++ b/test/nonregression/pipelines/test_run_pipelines_dwi.py @@ -15,6 +15,11 @@ import pytest from numpy.testing import assert_array_almost_equal +from clinica.utils.testing_utils import ( + assert_large_nifti_almost_equal, + assert_nifti_almost_equal, +) + # Determine location for working_directory warnings.filterwarnings("ignore") @@ -77,31 +82,19 @@ def test_dwi_perform_ants_registration(cmdopt, tmp_path): ) ref_file = fspath(ref_dir / "transform1Warp.nii.gz") - out_img = nib.load(out_file) - ref_img = nib.load(ref_file) - - # assert similarity_measure(out_file, ref_file, 0.97) - assert_array_almost_equal(out_img.get_fdata(), ref_img.get_fdata()) + assert_nifti_almost_equal(out_file, ref_file) out_file = fspath( tmp_path / "tmp" / "epi_correction_image_warped" / "transformWarped.nii.gz" ) ref_file = fspath(ref_dir / "transformWarped.nii.gz") - out_img = nib.load(out_file) - ref_img = nib.load(ref_file) - - # assert similarity_measure(out_file, ref_file, 0.97) - assert_array_almost_equal(out_img.get_fdata(), ref_img.get_fdata()) + assert_nifti_almost_equal(out_file, ref_file) out_file = fspath(tmp_path / "tmp" / "merged_transforms" / "transform1Warp.nii.gz") ref_file = fspath(ref_dir / "merged_transform.nii.gz") - out_img = nib.load(out_file) - ref_img = nib.load(ref_file) - - # assert similarity_measure(out_file, ref_file, 0.97) - assert_array_almost_equal(out_img.get_fdata(), ref_img.get_fdata()) + assert_nifti_almost_equal(out_file, ref_file) out_file = fspath( tmp_path / "tmp" / "rotated_b_vectors" / "sub-01_ses-M000_dwi_rotated.bvec" @@ -149,13 +142,7 @@ def test_dwi_perform_dwi_epi_correction(cmdopt, tmp_path): ) ref_file = fspath(ref_dir / "Jacobian_image_maths_thresh_merged.nii.gz") - out_img = nib.load(out_file) - ref_img = nib.load(ref_file) - - assert out_img.shape == ref_img.shape - assert_array_almost_equal(out_img.affine, ref_img.affine) - - assert similarity_measure(out_file, ref_file, 0.97) + assert_large_nifti_almost_equal(out_file, ref_file) @pytest.mark.slow