
4855 5860 update the pending transform utilities #5916

Merged: 19 commits merged on Feb 1, 2023
2 changes: 1 addition & 1 deletion monai/data/image_reader.py
@@ -1031,7 +1031,7 @@ def _get_array_data(self, img):
img: a Nibabel image object loaded from an image file.

"""
return np.asanyarray(img.dataobj)
return np.asanyarray(img.dataobj, order="C")


class NumpyReader(ImageReader):
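The one-line change above asks Nibabel's array proxy to be materialized in C (row-major) order, which keeps later conversion steps from having to copy Fortran-ordered data. A minimal sketch of the effect, assuming any NIfTI volume on disk (the path is a placeholder):

    import nibabel as nib
    import numpy as np

    img = nib.load("example.nii.gz")  # placeholder path

    # order="C" returns a C-contiguous array, avoiding an extra copy when the
    # data is subsequently made contiguous for tensor conversion
    data = np.asanyarray(img.dataobj, order="C")
    print(data.flags["C_CONTIGUOUS"])  # True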
17 changes: 12 additions & 5 deletions monai/data/meta_obj.py
@@ -19,8 +19,7 @@
import numpy as np
import torch

from monai.utils.enums import TraceKeys
from monai.utils.misc import first
from monai.utils import TraceKeys, first, is_immutable

_TRACK_META = True

@@ -107,27 +106,35 @@ def flatten_meta_objs(*args: Iterable):
@staticmethod
def copy_items(data):
"""returns a copy of the data. list and dict are shallow copied for efficiency purposes."""
if is_immutable(data):
return data
if isinstance(data, (list, dict, np.ndarray)):
return data.copy()
if isinstance(data, torch.Tensor):
return data.detach().clone()
return deepcopy(data)

def copy_meta_from(self, input_objs, copy_attr=True) -> None:
def copy_meta_from(self, input_objs, copy_attr=True, keys=None):
"""
Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances.

Args:
input_objs: list of `MetaObj` to copy data from.
copy_attr: whether to copy each attribute with `MetaObj.copy_item`.
note that if the attribute is a nested list or dict, only a shallow copy will be done.
keys: the keys of attributes to copy from the ``input_objs``.
If None, all keys from the input_objs will be copied.
"""
first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
if not hasattr(first_meta, "__dict__"):
return self
first_meta = first_meta.__dict__
keys = first_meta.keys() if keys is None else keys
if not copy_attr:
self.__dict__ = first_meta.copy() # shallow copy for performance
self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta} # shallow copy for performance
else:
self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in first_meta})
self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta})
return self

@staticmethod
def get_default_meta() -> dict:
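To illustrate the copying behaviour above: `MetaObj.copy_items` now returns immutable values as-is, shallow-copies containers and ndarrays, and detaches/clones tensors, while `copy_meta_from` can additionally restrict the copy to a subset of attribute `keys`. A small sketch of `copy_items`, assuming the helper behaves as in the diff:

    import numpy as np
    import torch
    from monai.data.meta_obj import MetaObj

    # immutable values are returned unchanged (no deepcopy overhead)
    s = "spleen_01.nii.gz"
    assert MetaObj.copy_items(s) is s

    # lists, dicts and ndarrays are shallow-copied
    arr = np.ones(3)
    assert MetaObj.copy_items(arr) is not arr

    # tensors are detached and cloned
    t = torch.ones(3, requires_grad=True)
    assert MetaObj.copy_items(t).requires_grad is False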
24 changes: 17 additions & 7 deletions monai/data/meta_tensor.py
@@ -503,15 +503,15 @@ def clone(self):

@staticmethod
def ensure_torch_and_prune_meta(
im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
):
"""
Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
convert that to `torch.Tensor`, too. Remove any superfluous metadata.

Args:
im: Input image (`np.ndarray` or `torch.Tensor`)
meta: Metadata dictionary.
meta: Metadata dictionary. When it is None, the metadata is not tracked and this method returns a torch.Tensor.
simple_keys: whether to keep only a simple subset of metadata keys.
pattern: combined with `sep`, a regular expression used to match and prune keys
in the metadata (nested dictionary), default to None, no key deletion.
@@ -521,14 +521,17 @@ def ensure_torch_and_prune_meta(

Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im) # potentially ascontiguousarray
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

# if not tracking metadata, return `torch.Tensor`
if not get_track_meta() or meta is None:
if not isinstance(img, MetaTensor):
return img

if meta is None:
meta = {}

# remove any superfluous metadata.
if simple_keys:
# ensure affine is of type `torch.Tensor`
@@ -540,7 +543,14 @@ def ensure_torch_and_prune_meta(
meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)

# return the `MetaTensor`
return MetaTensor(img, meta=meta)
if meta is None:
meta = {}
img.meta = meta
if MetaKeys.AFFINE in meta:
img.affine = meta[MetaKeys.AFFINE] # this uses the affine property setter
else:
img.affine = MetaTensor.get_default_affine()
return img

def __repr__(self):
"""
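A short usage sketch of the updated `ensure_torch_and_prune_meta`, assuming metadata tracking is globally enabled (the default):

    import numpy as np
    from monai.data import MetaTensor

    im = np.zeros((1, 8, 8), dtype=np.float32)

    # with a metadata dict, a MetaTensor is returned; its affine comes from the
    # dict (or the default affine when no "affine" key is present)
    out = MetaTensor.ensure_torch_and_prune_meta(im, {"affine": np.eye(4)})
    print(type(out).__name__, tuple(out.affine.shape))  # MetaTensor (4, 4)

    # with meta=None the metadata is not tracked and a plain torch.Tensor comes back
    out = MetaTensor.ensure_torch_and_prune_meta(im, None)
    print(type(out).__name__)  # Tensor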
2 changes: 2 additions & 0 deletions monai/data/utils.py
@@ -46,6 +46,7 @@
ensure_tuple_size,
fall_back_tuple,
first,
get_equivalent_dtype,
issequenceiterable,
look_up_option,
optional_import,
@@ -924,6 +925,7 @@ def to_affine_nd(r: np.ndarray | int, affine: NdarrayTensor, dtype=np.float64) -
an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)

"""
dtype = get_equivalent_dtype(dtype, np.ndarray)
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
affine_np = affine_np.copy()
if affine_np.ndim != 2:
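Since the `dtype` argument is now normalized with `get_equivalent_dtype` before the numpy conversion, a torch dtype can also be supplied. A minimal sketch:

    import torch
    from monai.data.utils import to_affine_nd

    affine_2d = torch.eye(3)  # a 2D (3x3) affine

    # a torch dtype is mapped to the equivalent numpy dtype internally
    affine_3d = to_affine_nd(3, affine_2d, dtype=torch.float32)
    print(tuple(affine_3d.shape))  # (4, 4); the output type follows the input affine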
188 changes: 137 additions & 51 deletions monai/transforms/inverse.py
@@ -20,9 +20,11 @@
import torch

from monai import transforms
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.transforms.transform import Transform
from monai.utils.enums import TraceKeys
from monai.data.utils import to_affine_nd
from monai.transforms.transform import LazyTransform, Transform
from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor

__all__ = ["TraceableTransform", "InvertibleTransform"]

@@ -72,76 +74,160 @@ def trace_key(key: Hashable = None):
return f"{TraceKeys.KEY_SUFFIX}"
return f"{key}{TraceKeys.KEY_SUFFIX}"

def get_transform_info(
self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None
) -> dict:
@staticmethod
def transform_info_keys():
"""The keys to store necessary info of an applied transform."""
return (
TraceKeys.CLASS_NAME,
TraceKeys.ID,
TraceKeys.TRACING,
TraceKeys.LAZY_EVALUATION,
TraceKeys.DO_TRANSFORM,
)

def get_transform_info(self) -> dict:
"""
Return a dictionary with the relevant information pertaining to an applied transform.
"""
vals = (
self.__class__.__name__,
id(self),
self.tracing,
self.lazy_evaluation if isinstance(self, LazyTransform) else False,
self._do_transform if hasattr(self, "_do_transform") else True,
)
return dict(zip(self.transform_info_keys(), vals))

Args:
data: input data. Can be dictionary or MetaTensor. We can use `shape` to
determine the original size of the object (unless that has been given
explicitly, see `orig_size`).
key: if data is a dictionary, data[key] will be modified.
extra_info: if desired, any extra information pertaining to the applied
transform can be stored in this dictionary. These are often needed for
computing the inverse transformation.
orig_size: sometimes during the inverse it is useful to know what the size
of the original image was, in which case it can be supplied here.
def push_transform(self, data, *args, **kwargs):
"""
Push to a stack of applied transforms of ``data``.

Returns:
Dictionary of data pertaining to the applied transformation.
Args:
data: dictionary of data or `MetaTensor`.
args: additional positional arguments to track_transform_meta.
kwargs: additional keyword arguments to track_transform_meta,
set ``replace=True`` (default False) to rewrite the last transform info in
applied_operation/pending_operation based on ``self.get_transform_info()``.
"""
info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)}
if orig_size is not None:
info[TraceKeys.ORIG_SIZE] = orig_size
elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"):
info[TraceKeys.ORIG_SIZE] = data[key].shape[1:]
elif hasattr(data, "shape"):
info[TraceKeys.ORIG_SIZE] = data.shape[1:]
if extra_info is not None:
info[TraceKeys.EXTRA_INFO] = extra_info
# If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
if hasattr(self, "_do_transform"): # RandomizableTransform
info[TraceKeys.DO_TRANSFORM] = self._do_transform
return info

def push_transform(
self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None
) -> None:
transform_info = self.get_transform_info()
lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False)
do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True)
kwargs = kwargs or {}
replace = kwargs.pop("replace", False) # whether to rewrite the most recently pushed transform info
if replace and get_track_meta() and isinstance(data, MetaTensor):
if not lazy_eval:
xform = self.pop_transform(data, check=False) if do_transform else {}
meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform)
return data.copy_meta_from(meta_obj)
if do_transform:
xform = data.pending_operations.pop() # type: ignore
xform.update(transform_info)
meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval)
return data.copy_meta_from(meta_obj)
return data
kwargs["lazy_evaluation"] = lazy_eval
if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict):
kwargs["transform_info"].update(transform_info)
else:
kwargs["transform_info"] = transform_info
meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs)
return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data

@classmethod
def track_transform_meta(
cls,
data,
key: Hashable = None,
sp_size=None,
affine=None,
extra_info: dict | None = None,
orig_size: tuple | None = None,
transform_info=None,
lazy_evaluation=False,
):
"""
Push to a stack of applied transforms.
Update the stack of applied/pending transform metadata of ``data``.

Args:
data: dictionary of data or `MetaTensor`.
key: if data is a dictionary, data[key] will be modified.
sp_size: the expected output spatial size when the transform is applied.
it can be a tensor or a numpy array, but will be converted to a list of integers.
affine: the affine representation of the (spatial) transform in the image space.
When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``.
extra_info: if desired, any extra information pertaining to the applied
transform can be stored in this dictionary. These are often needed for
computing the inverse transformation.
orig_size: sometimes during the inverse it is useful to know what the size
of the original image was, in which case it can be supplied here.
transform_info: info from self.get_transform_info().
lazy_evaluation: whether to push the transform to pending_operations or applied_operations.

Returns:
None, but data has been updated to store the applied transformation.

For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with
updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata.
"""
if not self.tracing:
return
info = self.get_transform_info(data, key, extra_info, orig_size)
data_t = data[key] if key is not None else data # compatible with the dict data representation
out_obj = MetaObj()
# after deprecating metadict, we should always convert data_t to metatensor here
if isinstance(data_t, MetaTensor):
out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys())

if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor):
# not lazy evaluation, directly update the metatensor affine (don't push to the stack)
orig_affine = data_t.peek_pending_affine()
orig_affine = convert_to_dst_type(orig_affine, affine)[0]
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype)
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))

if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
if isinstance(data, Mapping):
if not isinstance(data, dict):
data = dict(data)
data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
return data
return out_obj # return with data_t as tensor if get_track_meta() is False

info = transform_info
# track the current spatial shape
if orig_size is not None:
info[TraceKeys.ORIG_SIZE] = orig_size
elif isinstance(data_t, MetaTensor):
info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape()
elif hasattr(data_t, "shape"):
info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
# include extra_info
if extra_info is not None:
info[TraceKeys.EXTRA_INFO] = extra_info

if isinstance(data, MetaTensor):
data.push_applied_operation(info)
elif isinstance(data, Mapping):
if key in data and isinstance(data[key], MetaTensor):
data[key].push_applied_operation(info)
# push the transform info to the applied_operation or pending_operation stack
if lazy_evaluation:
if sp_size is None:
if LazyAttr.SHAPE not in info:
warnings.warn("spatial size is None in push transform.")
else:
info[LazyAttr.SHAPE] = tuple(convert_to_numpy(sp_size, wrap_sequence=True).tolist())
if affine is None:
if LazyAttr.AFFINE not in info:
warnings.warn("affine is None in push transform.")
else:
# If this is the first, create list
if self.trace_key(key) not in data:
if not isinstance(data, dict):
data = dict(data)
data[self.trace_key(key)] = []
data[self.trace_key(key)].append(info)
info[LazyAttr.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
out_obj.push_pending_operation(info)
else:
warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.")
out_obj.push_applied_operation(info)
if isinstance(data, Mapping):
if not isinstance(data, dict):
data = dict(data)
if isinstance(data_t, MetaTensor):
data[key] = data_t.copy_meta_from(out_obj)
else:
x_k = TraceableTransform.trace_key(key)
if x_k not in data:
data[x_k] = [] # If this is the first, create list
data[x_k].append(info)
return data
return out_obj

def check_transforms_match(self, transform: Mapping) -> None:
"""Check transforms are of same instance."""
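To see the applied-transform stack that `push_transform`/`track_transform_meta` maintain, a small sketch using an invertible spatial transform; this assumes metadata tracking is enabled and that `Flip` records itself through the code path above:

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Flip
    from monai.utils import TraceKeys

    img = MetaTensor(torch.arange(8.0).reshape(1, 2, 4))  # channel-first image

    flipped = Flip(spatial_axis=0)(img)

    # the transform pushed a record onto applied_operations
    record = flipped.applied_operations[-1]
    print(record[TraceKeys.CLASS_NAME])  # Flip
    print(record[TraceKeys.ORIG_SIZE])   # the spatial size before the transform, (2, 4)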
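And a sketch of the lazy path, calling the new `track_transform_meta` classmethod directly with a hand-built `transform_info` dict (a real transform would produce it via `get_transform_info()`; the class name and id below are placeholders):

    import torch
    from monai.data import MetaTensor
    from monai.transforms.inverse import TraceableTransform
    from monai.utils import LazyAttr, TraceKeys

    img = MetaTensor(torch.zeros(1, 4, 4))

    info = {TraceKeys.CLASS_NAME: "DemoFlip", TraceKeys.ID: 0, TraceKeys.TRACING: True}

    meta_obj = TraceableTransform.track_transform_meta(
        img,
        sp_size=(4, 4),
        affine=torch.eye(3),
        transform_info=info,
        lazy_evaluation=True,  # push to pending_operations instead of applied_operations
    )
    img.copy_meta_from(meta_obj)

    pending = img.pending_operations[-1]
    print(pending[LazyAttr.SHAPE])   # (4, 4)
    print(pending[LazyAttr.AFFINE])  # 3x3 identity, as a cpu tensor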