Skip to content

Commit

Permalink
fixes resampling niftisaver (#3308)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Nov 11, 2021
1 parent 8076372 commit fb66ba0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
15 changes: 8 additions & 7 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __init__(
output_dir: output image directory.
output_postfix: a string appended to all output file names.
output_ext: output file extension name.
resample: whether to resample before saving the data array.
resample: whether to convert the data array to it's original coordinate system
based on `original_affine` in the `meta_data`.
mode: {``"bilinear"``, ``"nearest"``}
This option is used when ``resample = True``.
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
Expand Down Expand Up @@ -107,7 +108,7 @@ def __init__(

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""
Save data into a Nifti file.
Save data into a NIfTI file.
The meta_data could optionally have the following keys:
- ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
Expand All @@ -116,7 +117,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
- ``'spatial_shape'`` -- for data output shape.
- ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename.
When meta_data is specified, the saver will try to resample batch data from the space
When meta_data is specified and `resample=True`, the saver will try to resample batch data from the space
defined by "affine" to the space defined by "original_affine".
If meta_data is None, use the default index (starting from 0) as the filename.
Expand All @@ -131,7 +132,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
"""
filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
self._data_index += 1
original_affine = meta_data.get("original_affine", None) if meta_data else None
original_affine = meta_data.get("original_affine", None) if meta_data and self.resample else None
affine = meta_data.get("affine", None) if meta_data else None
spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None
patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None
Expand All @@ -151,7 +152,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
# change data shape to be (channel, h, w, d)
while len(data.shape) < 4:
data = np.expand_dims(data, -1)
# change data to "channel last" format and write to nifti format file
# change data to "channel last" format and write to NIfTI format file
data = np.moveaxis(np.asarray(data), 0, -1)

# if desired, remove trailing singleton dimensions
Expand All @@ -164,7 +165,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
file_name=path,
affine=affine,
target_affine=original_affine,
resample=self.resample,
resample=True,
output_spatial_shape=spatial_shape,
mode=self.mode,
padding_mode=self.padding_mode,
Expand All @@ -178,7 +179,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""
Save a batch of data into Nifti format files.
Save a batch of data into NIfTI format files.
Spatially it supports up to three dimensions, that is, H, HW, HWD for
1D, 2D, 3D respectively (with resampling supports for 2D and 3D only).
Expand Down
19 changes: 19 additions & 0 deletions tests/test_nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def test_saved_3d_resize_content(self):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_3d_no_resize_content(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = NiftiSaver(
output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False
)

meta_data = {
"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
"spatial_shape": [(10, 10, 2)] * 8,
"affine": [np.diag(np.ones(4)) * 5] * 8,
"original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
}
saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
img, _ = LoadImage("nibabelreader")(filepath)
self.assertEqual(img.shape, (1, 2, 2, 8))

def test_squeeze_end_dims(self):
with tempfile.TemporaryDirectory() as tempdir:

Expand Down

0 comments on commit fb66ba0

Please sign in to comment.