From fb66ba0625c3a64ba7cdba9811a9997b336e3702 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 11 Nov 2021 14:03:01 +0000 Subject: [PATCH] fixes resampling niftisaver (#3308) Signed-off-by: Wenqi Li --- monai/data/nifti_saver.py | 15 ++++++++------- tests/test_nifti_saver.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index b7067def73..427b2d29d5 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -56,7 +56,8 @@ def __init__( output_dir: output image directory. output_postfix: a string appended to all output file names. output_ext: output file extension name. - resample: whether to resample before saving the data array. + resample: whether to convert the data array to it's original coordinate system + based on `original_affine` in the `meta_data`. mode: {``"bilinear"``, ``"nearest"``} This option is used when ``resample = True``. Interpolation mode to calculate output values. Defaults to ``"bilinear"``. @@ -107,7 +108,7 @@ def __init__( def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ - Save data into a Nifti file. + Save data into a NIfTI file. The meta_data could optionally have the following keys: - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object. @@ -116,7 +117,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] - ``'spatial_shape'`` -- for data output shape. - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. - When meta_data is specified, the saver will try to resample batch data from the space + When meta_data is specified and `resample=True`, the saver will try to resample batch data from the space defined by "affine" to the space defined by "original_affine". If meta_data is None, use the default index (starting from 0) as the filename. @@ -131,7 +132,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] """ filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 - original_affine = meta_data.get("original_affine", None) if meta_data else None + original_affine = meta_data.get("original_affine", None) if meta_data and self.resample else None affine = meta_data.get("affine", None) if meta_data else None spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None @@ -151,7 +152,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] # change data shape to be (channel, h, w, d) while len(data.shape) < 4: data = np.expand_dims(data, -1) - # change data to "channel last" format and write to nifti format file + # change data to "channel last" format and write to NIfTI format file data = np.moveaxis(np.asarray(data), 0, -1) # if desired, remove trailing singleton dimensions @@ -164,7 +165,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] file_name=path, affine=affine, target_affine=original_affine, - resample=self.resample, + resample=True, output_spatial_shape=spatial_shape, mode=self.mode, padding_mode=self.padding_mode, @@ -178,7 +179,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ - Save a batch of data into Nifti format files. + Save a batch of data into NIfTI format files. Spatially it supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D respectively (with resampling supports for 2D and 3D only). diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index e22a6e6620..3cbb24c69e 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -64,6 +64,25 @@ def test_saved_3d_resize_content(self): filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + def test_saved_3d_no_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False + ) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + img, _ = LoadImage("nibabelreader")(filepath) + self.assertEqual(img.shape, (1, 2, 2, 8)) + def test_squeeze_end_dims(self): with tempfile.TemporaryDirectory() as tempdir: