Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6007 reverse_indexing for PILReader #6008

Merged
merged 3 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,13 +1141,17 @@ class PILReader(ImageReader):
Args:
converter: additional function to convert the image data after `read()`.
for example, use `converter=lambda image: image.convert("LA")` to convert image format.
reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default,
so that output of the reader is consistent with the other readers. Set this option to ``False`` to use
the PIL backend's original spatial axes convention.
kwargs: additional args for `Image.open` API in `read()`, mode details about available args:
https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
"""

def __init__(self, converter: Callable | None = None, **kwargs):
def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):
super().__init__()
self.converter = converter
self.reverse_indexing = reverse_indexing
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand Down Expand Up @@ -1194,8 +1198,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
It computes `spatial_shape` and stores it in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to represent the output metadata.
Note that it will swap axis 0 and 1 after loading the array because the `HW` definition in PIL
is different from other common medical packages.
Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading
the array because the spatial axes definition in PIL is different from other common medical packages.

Args:
img: a PIL Image object loaded from a file or a list of PIL Image objects.
Expand All @@ -1207,7 +1211,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
data = np.moveaxis(np.asarray(i), 0, 1)
data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
img_array.append(data)
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ class LoadImage(Transform):
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).

Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
loading the array because the `HW` definition for non-medical specific file formats is different
from other common medical packages.
Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after
loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition
for non-medical specific file formats is different from other common medical packages.

See also:

Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class LoadImaged(MapTransform):
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader).

Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
loading the array because the `HW` definition for non-medical specific file formats is different
from other common medical packages.
Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after
loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition
for non-medical specific file formats is different from other common medical packages.

Note:

Expand Down
9 changes: 5 additions & 4 deletions tests/test_pil_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)]

TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128)]
TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128), False]

TEST_CASE_4 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)]

Expand All @@ -38,20 +38,21 @@

class TestPNGReader(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape):
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True):
test_image = np.random.randint(0, 256, size=data_shape)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
reader = PILReader(mode="r")
reader = PILReader(mode="r", reverse_indexing=reverse)
result = reader.get_data(reader.read(filenames))
# load image by PIL and compare the result
test_image = np.asarray(Image.open(filenames[0]))

self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape)
self.assertTupleEqual(result[0].shape, expected_shape)
test_image = np.moveaxis(test_image, 0, 1)
if reverse:
test_image = np.moveaxis(test_image, 0, 1)
if result[0].shape == test_image.shape:
np.testing.assert_allclose(result[0], test_image)
else:
Expand Down