Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6086 6087 nan to indicate no_channel, split dim singleton #6090

Merged
merged 3 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from monai.data.utils import (
affine_to_spacing,
correct_nifti_header_if_necessary,
is_no_channel,
is_supported_format,
orientation_ras_lps,
)
Expand Down Expand Up @@ -162,7 +163,7 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
def _stack_images(image_list: list, meta_dict: dict):
if len(image_list) <= 1:
return image_list[0]
if meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) not in ("no_channel", None):
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
return np.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
Expand Down Expand Up @@ -213,7 +214,7 @@ def __init__(
):
super().__init__()
self.kwargs = kwargs
self.channel_dim = channel_dim
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.series_name = series_name
self.reverse_indexing = reverse_indexing
self.series_meta = series_meta
Expand Down Expand Up @@ -305,7 +306,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
if self.channel_dim is None: # default to "no_channel" or -1
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
)
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
Expand Down Expand Up @@ -435,7 +436,7 @@ class PydicomReader(ImageReader):

def __init__(
self,
channel_dim: int | None = None,
channel_dim: str | int | None = None,
affine_lps_to_ras: bool = True,
swap_ij: bool = True,
prune_metadata: bool = True,
Expand All @@ -444,7 +445,7 @@ def __init__(
):
super().__init__()
self.kwargs = kwargs
self.channel_dim = channel_dim
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.affine_lps_to_ras = affine_lps_to_ras
self.swap_ij = swap_ij
self.prune_metadata = prune_metadata
Expand Down Expand Up @@ -629,7 +630,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
metadata[MetaKeys.AFFINE] = affine.copy()
if self.channel_dim is None: # default to "no_channel" or -1
metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1
float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1
)
else:
metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
Expand Down Expand Up @@ -883,14 +884,14 @@ class NibabelReader(ImageReader):
@deprecated_arg("dtype", since="1.0", msg_suffix="please modify dtype of the returned by ``get_data`` instead.")
def __init__(
self,
channel_dim: int | None = None,
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
dtype: DtypeLike = np.float32,
**kwargs,
):
super().__init__()
self.channel_dim = channel_dim
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
self.dtype = dtype # deprecated
Expand Down Expand Up @@ -965,7 +966,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img_array.append(data)
if self.channel_dim is None: # default to "no_channel" or -1
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
)
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
Expand Down Expand Up @@ -1018,8 +1019,8 @@ def _get_spatial_shape(self, img):
dim = np.insert(dim, 0, 3)
ndim = dim[0]
size = list(dim[1:])
if self.channel_dim is not None:
size.pop(self.channel_dim)
if not is_no_channel(self.channel_dim):
size.pop(int(self.channel_dim)) # type: ignore
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

Expand Down Expand Up @@ -1049,12 +1050,12 @@ class NumpyReader(ImageReader):

"""

def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: int | None = None, **kwargs):
def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs):
super().__init__()
if npz_keys is not None:
npz_keys = ensure_tuple(npz_keys)
self.npz_keys = npz_keys
self.channel_dim = channel_dim
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand Down Expand Up @@ -1126,7 +1127,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.SPACE] = SpaceKeys.RAS
img_array.append(i)
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
self.channel_dim if isinstance(self.channel_dim, int) else "no_channel"
self.channel_dim if isinstance(self.channel_dim, int) else float("nan")
)
_copy_compatible_dict(header, compatible_meta)

Expand Down Expand Up @@ -1214,7 +1215,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
img_array.append(data)
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
)
_copy_compatible_dict(header, compatible_meta)

Expand Down Expand Up @@ -1532,13 +1533,13 @@ class NrrdReader(ImageReader):

def __init__(
self,
channel_dim: int | None = None,
channel_dim: str | int | None = None,
dtype: np.dtype | type | str | None = np.float32,
index_order: str = "F",
affine_lps_to_ras: bool = True,
**kwargs,
):
self.channel_dim = channel_dim
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.dtype = dtype
self.index_order = index_order
self.affine_lps_to_ras = affine_lps_to_ras
Expand Down Expand Up @@ -1605,7 +1606,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:

if self.channel_dim is None: # default to "no_channel" or -1
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0
float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0
)
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
Expand Down
12 changes: 12 additions & 0 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
"remove_extra_metadata",
"get_extra_metadata_keys",
"PICKLE_KEY_SUFFIX",
"is_no_channel",
]

# module to be used by `torch.save`
Expand Down Expand Up @@ -1529,3 +1530,14 @@ def get_extra_metadata_keys() -> list[str]:
# ]

return keys


def is_no_channel(val) -> bool:
"""Returns whether `val` indicates "no_channel", for MetaKeys.ORIGINAL_CHANNEL_DIM."""
if isinstance(val, torch.Tensor):
return bool(torch.isnan(val))
if isinstance(val, str):
return val == "no_channel"
if np.isscalar(val):
return bool(np.isnan(val))
return val is None
15 changes: 7 additions & 8 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import no_collation
from monai.data.utils import is_no_channel, no_collation
from monai.networks.layers.simplelayers import (
ApplyFilter,
EllipticalFilter,
Expand All @@ -54,6 +54,7 @@
)
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
from monai.utils import (
MetaKeys,
TraceKeys,
convert_data_type,
convert_to_cupy,
Expand Down Expand Up @@ -267,9 +268,9 @@ def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch
if isinstance(img, MetaTensor):
meta_dict = img.meta

channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None
channel_dim = meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) if isinstance(meta_dict, Mapping) else None
if self.input_channel_dim is not None:
channel_dim = self.input_channel_dim
channel_dim = float("nan") if self.input_channel_dim == "no_channel" else self.input_channel_dim

if channel_dim is None:
msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`."
Expand All @@ -280,12 +281,12 @@ def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch

# track the original channel dim
if isinstance(meta_dict, dict):
meta_dict["original_channel_dim"] = channel_dim
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = channel_dim

if channel_dim == "no_channel":
if is_no_channel(channel_dim):
result = img[None]
else:
result = moveaxis(img, channel_dim, 0) # type: ignore
result = moveaxis(img, int(channel_dim), 0) # type: ignore

return convert_to_tensor(result, track_meta=get_track_meta()) # type: ignore

Expand Down Expand Up @@ -371,8 +372,6 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]:
Apply the transform to `img`.
"""
n_out = img.shape[self.dim]
if n_out <= 1:
raise RuntimeError(f"Input image is singleton along dimension to be split, got shape {img.shape}.")
if isinstance(img, torch.Tensor):
outputs = list(torch.split(img, 1, self.dim))
else:
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ class MetaKeys(StrEnum):
ORIGINAL_AFFINE = "original_affine" # the affine after image loading before any data processing
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or "no_channel"
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")


class ColorOrder(StrEnum):
Expand Down
7 changes: 3 additions & 4 deletions tests/test_splitdim.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ def test_correct_shape(self, shape, keepdim, im_type):
arr[0, 0, 0, 0] *= 2
self.assertEqual(arr.flatten()[0], out[0].flatten()[0])

def test_error(self):
"""Should fail because splitting along singleton dimension"""
def test_singleton(self):
shape = (2, 1, 8, 7)
for p in TEST_NDARRAYS:
arr = p(np.random.rand(*shape))
with self.assertRaises(RuntimeError):
_ = SplitDim(dim=1)(arr)
out = SplitDim(dim=1)(arr)
self.assertEqual(out[0].shape, shape)


if __name__ == "__main__":
Expand Down
9 changes: 4 additions & 5 deletions tests/test_splitdimd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def setUpClass(cls) -> None:
affine = make_rand_affine()
data = {"i": make_nifti_image(arr, affine)}

loader = LoadImaged("i")
loader = LoadImaged("i", image_only=True)
cls.data = loader(data)

@parameterized.expand(TESTS)
Expand Down Expand Up @@ -84,13 +84,12 @@ def test_correct(self, keepdim, im_type, update_meta, list_output):
arr[0, 0, 0, 0] *= 2
self.assertEqual(arr.flatten()[0], out.flatten()[0])

def test_error(self):
"""Should fail because splitting along singleton dimension"""
def test_singleton(self):
shape = (2, 1, 8, 7)
for p in TEST_NDARRAYS:
arr = p(np.random.rand(*shape))
with self.assertRaises(RuntimeError):
_ = SplitDimd("i", dim=1)({"i": arr})
out = SplitDimd("i", dim=1)({"i": arr})
self.assertEqual(out["i"].shape, shape)


if __name__ == "__main__":
Expand Down