Skip to content

Commit

Permalink
6007 reverse_indexing for PILReader (#6008)
Browse files Browse the repository at this point in the history
Fixes #6007

### Description
- reverse_indexing = False:
to support consistency with PIL/torchvision
```py
img = LoadImage(image_only=True, ensure_channel_first=True, reverse_indexing=False)("MONAI-logo_color.png")  # PILReader
torchvision.utils.save_image(img, "MONAI-logo_color_torchvision.png", normalize=True)
```
- reverse_indexing = True:
to support consistency with other backends in monai
```py
img = LoadImage(image_only=True, ensure_channel_first=True, reader="PILReader", reverse_indexing=True)(filename)  # PIL backend
img_1 = LoadImage(image_only=True, ensure_channel_first=True, reader="ITKReader")(filename)  # itk backend
np.testing.assert_allclose(img, img_1)  # true
```

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Feb 16, 2023
1 parent f4902b2 commit 94e9e17
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
12 changes: 8 additions & 4 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,13 +1141,17 @@ class PILReader(ImageReader):
Args:
converter: additional function to convert the image data after `read()`.
for example, use `converter=lambda image: image.convert("LA")` to convert image format.
reverse_indexing: whether to swap axis 0 and 1 after loading the array. This is enabled by default,
so that the output of the reader is consistent with the other readers. Set this option to ``False`` to use
the PIL backend's original spatial axes convention.
kwargs: additional args for `Image.open` API in `read()`, more details about available args:
https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
"""

def __init__(self, converter: Callable | None = None, **kwargs):
def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):
super().__init__()
self.converter = converter
self.reverse_indexing = reverse_indexing
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand Down Expand Up @@ -1194,8 +1198,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
It computes `spatial_shape` and stores it in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to represent the output metadata.
Note that it will swap axis 0 and 1 after loading the array because the `HW` definition in PIL
is different from other common medical packages.
Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading
the array because the spatial axes definition in PIL is different from other common medical packages.
Args:
img: a PIL Image object loaded from a file or a list of PIL Image objects.
Expand All @@ -1207,7 +1211,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
data = np.moveaxis(np.asarray(i), 0, 1)
data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
img_array.append(data)
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ class LoadImage(Transform):
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).
Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
loading the array because the `HW` definition for non-medical specific file formats is different
from other common medical packages.
Please note that for png, jpg, bmp, and other 2D formats, readers by default (``reverse_indexing=True``)
swap axis 0 and 1 after loading the array, because the spatial axes definition
for non-medical specific file formats is different from other common medical packages.
See also:
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class LoadImaged(MapTransform):
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader).
Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
loading the array because the `HW` definition for non-medical specific file formats is different
from other common medical packages.
Please note that for png, jpg, bmp, and other 2D formats, readers by default (``reverse_indexing=True``)
swap axis 0 and 1 after loading the array, because the spatial axes definition
for non-medical specific file formats is different from other common medical packages.
Note:
Expand Down
9 changes: 5 additions & 4 deletions tests/test_pil_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)]

TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128)]
TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128), False]

TEST_CASE_4 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)]

Expand All @@ -38,20 +38,21 @@

class TestPNGReader(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape):
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True):
test_image = np.random.randint(0, 256, size=data_shape)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
reader = PILReader(mode="r")
reader = PILReader(mode="r", reverse_indexing=reverse)
result = reader.get_data(reader.read(filenames))
# load image by PIL and compare the result
test_image = np.asarray(Image.open(filenames[0]))

self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape)
self.assertTupleEqual(result[0].shape, expected_shape)
test_image = np.moveaxis(test_image, 0, 1)
if reverse:
test_image = np.moveaxis(test_image, 0, 1)
if result[0].shape == test_image.shape:
np.testing.assert_allclose(result[0], test_image)
else:
Expand Down

0 comments on commit 94e9e17

Please sign in to comment.