# Developer Guide: Lazy Transforms
Transforms are refactored into multiple layers:
- functional
- array-based
  - deterministic
  - random
- dictionary-based
  - deterministic
  - random
Functional transforms are the base of any transform implementation. They are stateless implementations of the actual transform operation to be carried out, and all of them can operate either in immediate mode (the operation is defined and then immediately applied) or in lazy mode (the operation is added to the MetaTensor's pending list).
Functional transforms have the following pattern:

```python
def functional_operation(
    img: torch.Tensor,
    ...,  # operation-specific parameters
    shape_override: Optional[Sequence[int]] = None,
    lazy_evaluation: Optional[bool] = True,
):
    img_ = convert_to_tensor(img, track_meta=get_track_meta())

    # the effective shape of the image can differ from its actual current shape
    # when the image has one or more pending transforms to be applied. Transforms
    # typically need the shape that the image will have at the point this transform
    # is carried out, rather than the shape it has at the point this transform
    # is defined
    input_shape = img_.shape if shape_override is None else shape_override

    # this is typically needed to fully specify the transform
    input_ndim = len(input_shape) - 1

    transform = get_a_specific_homogeneous_matrix_or_grid_describing_the_operation(...)

    # this might be needed if the transform is known to change the shape of the
    # resulting image
    im_extents = extents_from_shape(input_shape)
    im_extents = [transform @ e for e in im_extents]
    output_shape = shape_from_extents(input_shape, im_extents)

    # everything required to specify the transform at the point that it is applied.
    # note that "shape_override" should always be set, as this is how chains of lazy
    # transforms pass the correct image shape on to the next transform
    metadata = {
        ...,
        "shape_override": output_shape,
    }

    # either apply the operation immediately or just append it to the pending list
    return lazily_apply_op(img_, MetaMatrix(transform, metadata), lazy_evaluation)
```
`lazily_apply_op` is defined as follows:
```python
def lazily_apply_op(
    tensor, op, lazy_evaluation
) -> Union[MetaTensor, Tuple[torch.Tensor, Optional[MetaMatrix]]]:
    if isinstance(tensor, MetaTensor):
        # MetaTensors carry their own pending list, so the op is always pushed
        tensor.push_pending_operation(op)
        if lazy_evaluation is False:
            result = apply(tensor)
            return result
        else:
            return tensor
    else:
        # plain tensors cannot carry pending ops; in lazy mode the op is handed
        # back to the caller
        if lazy_evaluation is False:
            result = apply(tensor, [op])
            return result, None
        else:
            return tensor, op
```
As a rule, functional transform metadata should include:
- the parameters specific to the transform (e.g. angles for rotation)
- parameters influencing the operation, such as `mode`, `padding_mode`, etc.
- `shape_override`, if the resulting shape differs from the shape passed in
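For example, the metadata for a rotate operation might look like this (the keys other than `shape_override` are illustrative, not a fixed schema):

```python
metadata = {
    "angle": 0.785,                  # transform-specific parameter (radians)
    "mode": "bilinear",              # parameters influencing the operation
    "padding_mode": "border",
    "align_corners": False,
    "shape_override": (1, 64, 64),   # shape the image will have after this op
}
```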
**Note:** this section is a discussion of a design option; it is not currently planned for implementation.
Instead of passing the overridden shape from functional transform to functional transform, all of the functional transforms could be made into functors that are called at the point the transform is actually applied. A functor transform would have the following implementation:
```python
def functional_operation_functor(
    img: torch.Tensor,
    ...,  # operation-specific parameters
    lazy_evaluation: Optional[bool] = True,
):
    def _inner(inner_img):
        img_ = convert_to_tensor(inner_img, track_meta=get_track_meta())

        # because the functor only runs at the point the transform is actually
        # applied, inner_img already has the correct shape; no shape_override
        # mechanism is needed
        input_shape = img_.shape

        # this is typically needed to fully specify the transform
        input_ndim = len(input_shape) - 1

        transform = get_a_specific_homogeneous_matrix_or_grid_describing_the_operation(...)

        # this might be needed if the transform is known to change the shape of the
        # resulting image
        im_extents = extents_from_shape(input_shape)
        im_extents = [transform @ e for e in im_extents]
        output_shape = shape_from_extents(input_shape, im_extents)

        # everything required to specify the transform at the point that it is
        # applied; note that no "shape_override" entry is needed here
        metadata = {
            ...
        }
        return MetaMatrix(transform, metadata)

    return lazily_apply_op(img, _inner, lazy_evaluation)
```
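With `MetaTensor` inputs, a chain of lazy functor transforms would then need no shape bookkeeping at definition time. In this sketch, `functional_rotate_functor` and `functional_zoom_functor` are hypothetical names following the pattern above:

```python
# neither call takes a shape_override parameter, because each functor reads the
# true shape only when the pending list is finally resolved
img = MetaTensor(torch.rand(1, 64, 64))
img = functional_rotate_functor(img, 0.785, lazy_evaluation=True)
img = functional_zoom_functor(img, 2.0, lazy_evaluation=True)
img = apply(img)  # functors are evaluated here, against up-to-date shapes
```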
Deterministic array transforms wrap a functional transform, determining the effective input shape from any pending transforms before calling it:

```python
class ADeterministicArrayTransform(InvertibleTransform, LazyTransform):

    backend = [TransformBackends.TORCH]

    def __init__(
        self,
        ...,  # transform-specific arguments
        lazy_evaluation: Optional[bool] = False,
    ):
        LazyTransform.__init__(self, lazy_evaluation)
        # set member variables for transform-specific arguments
        ...

    def __call__(
        self,
        img: NdarrayOrTensor,
        ...,  # call-time transform-specific arguments
        shape_override: Optional[Sequence] = None,
    ) -> NdarrayOrTensor:
        # determine transform-specific parameters to pass to the functional op
        ...

        # if no shape was passed in, take the effective shape from the most
        # recently pending transform, if any
        shape_override_ = shape_override
        if (shape_override_ is None and isinstance(img, MetaTensor) and
                img.has_pending_transforms):
            tx = img.peek_pending_transform()
            shape_override_ = tx.metadata.get("shape_override", None)

        img_t, _ = functional_operation(img, ..., shape_override_)
        return img_t

    def inverse(self, data):
        raise NotImplementedError()
```
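A concrete transform built on this pattern is used like any other array transform; in lazy mode the call itself is cheap and the resampling is deferred. `ARotateLikeTransform` below is a hypothetical subclass, not an existing class:

```python
xform = ARotateLikeTransform(..., lazy_evaluation=True)

img = MetaTensor(torch.rand(1, 64, 64))
img = xform(img)                    # no resampling yet; the op is pending
print(img.has_pending_transforms)   # True

img = apply(img)                    # pending operations resolved in one resample
```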
Random array transforms wrap the deterministic version of the transform:
```python
class ARandomArrayTransform(RandomizableTransform, InvertibleTransform, LazyTrait):

    def __init__(
        self,
        ...,  # transform-specific args
        prob: float = 0.1,
        lazy_evaluation: Optional[bool] = True,
    ):
        RandomizableTransform.__init__(self, prob)
        self.op = ADeterministicArrayTransform(..., lazy_evaluation=lazy_evaluation)
        self.random_params = 0  # some default value

    def randomize(self, data: Optional[Any] = None) -> None:
        super().randomize(None)
        if self._do_transform:
            self.random_params = self.R.some_random_parameterized_value()
        else:
            self.random_params = 0  # the default value again

    def __call__(
        self,
        img: NdarrayOrTensor,
        ...,  # call-time transform-specific args
        randomize: Optional[bool] = True,
        shape_override: Optional[Sequence] = None,
    ) -> NdarrayOrTensor:
        if randomize:
            self.randomize(data=img)

        return self.op(img, self.random_params, ..., shape_override)

    # lazy_evaluation is delegated to the wrapped deterministic transform
    @property
    def lazy_evaluation(self):
        return self.op.lazy_evaluation

    @lazy_evaluation.setter
    def lazy_evaluation(self, value):
        self.op.lazy_evaluation = value

    def inverse(
        self,
        data: NdarrayOrTensor,
    ):
        raise NotImplementedError()
```
An alternative design delegates parameter generation to a separate randomizer object instead of inheriting from `RandomizableTransform`:

```python
class ARandomArrayTransform(InvertibleTransform, LazyTrait, RandomizableTrait):

    def __init__(
        self,
        ...,  # transform-specific args
        lazy_evaluation: Optional[bool] = True,
    ):
        self.randomizer = ARandomizer(..., prob)
        self.op = ADeterministicArrayTransform(0, ..., lazy_evaluation)

    def __call__(
        self,
        img: NdarrayOrTensor,
        ...,  # call-time transform-specific args
        shape_override: Optional[Sequence] = None,
    ) -> NdarrayOrTensor:
        params = self.randomizer.sample(img)

        # TODO: the random transforms have been implemented to make use of array
        # ops, which creates a problem if the operation name for "RandRotate"
        # needs to be "RandRotate" instead of "Rotate". This can be done via
        # several approaches:
        # 1. use the functional op directly
        # 2. pass an override to the array op for the name
        return self.op(img, params, ..., shape_override)

    # lazy_evaluation is delegated to the wrapped deterministic transform
    @property
    def lazy_evaluation(self):
        return self.op.lazy_evaluation

    @lazy_evaluation.setter
    def lazy_evaluation(self, value):
        self.op.lazy_evaluation = value

    def inverse(
        self,
        data: NdarrayOrTensor,
    ):
        raise NotImplementedError()
```
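`ARandomizer` is not defined in this guide; a minimal sketch of what such an object might look like (entirely an assumption, including the parameter names) is:

```python
import numpy as np


class ARandomizer:
    """Hypothetical randomizer: samples a parameter with probability `prob`,
    otherwise returns a default (identity) value."""

    def __init__(self, param_range: tuple, prob: float = 0.1, default=0):
        self.param_range = param_range
        self.prob = prob
        self.default = default
        self.R = np.random.RandomState()

    def sample(self, data=None):
        # return the default when the transform is not to be applied
        if self.R.uniform() >= self.prob:
            return self.default
        low, high = self.param_range
        return self.R.uniform(low, high)
```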
Deterministic dictionary transforms wrap the corresponding array transform and apply it to each selected key:

```python
class ADeterministicDictionaryTransform(MapTransform, InvertibleTransform, LazyTrait):

    backend = ADeterministicArrayTransform.backend

    def __init__(
        self,
        keys: KeysCollection,
        ...,  # transform-specific args
        lazy_evaluation: Optional[bool] = True,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        # operation-specific member variables
        ...
        self.op = ADeterministicArrayTransform(..., lazy_evaluation=lazy_evaluation)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        rd = dict(data)
        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
            rd, self.mode, self.padding_mode, self.align_corners, self.dtype
        ):
            rd[key] = self.op(rd[key], ...)
        return rd

    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = self.op.inverse(d[key])
        return d
```
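Used on a data dictionary, such a transform pushes the same pending operation onto every selected key. `ASpatialDictTransform` below is a hypothetical concrete subclass:

```python
xform = ASpatialDictTransform(keys=("image", "label"), lazy_evaluation=True)

data = {
    "image": MetaTensor(torch.rand(1, 64, 64)),
    "label": MetaTensor(torch.randint(0, 2, (1, 64, 64))),
}
data = xform(data)  # both tensors now carry the op on their pending lists
```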
Random dictionary transforms wrap a random array transform and randomize once per call, so that every key is transformed with the same parameters:

```python
class ARandomDictionaryTransform(MapTransform, InvertibleTransform,
                                 LazyTrait, RandomizableTrait):

    def __init__(
        self,
        keys: KeysCollection,
        ...,  # transform-specific args
        allow_missing_keys: Optional[bool] = False,
        lazy_evaluation: Optional[bool] = True,
    ):
        self.keys = keys
        self.allow_missing_keys = allow_missing_keys
        self.op = ARandomArrayTransform(..., lazy_evaluation=lazy_evaluation)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]):
        rd = dict(data)

        first_key = self.first_key(rd)
        if first_key == ():
            out = convert_to_tensor(rd, track_meta=get_track_meta())
            return out

        # randomize once, against the first key, so that all keys share parameters
        self.op.randomize(rd[first_key])

        it = self.key_iterator(rd, self.mode, self.padding_mode, self.align_corners)
        for key, mode, padding_mode, align_corners in it:
            rd[key] = self.op(rd[key], mode=mode, padding_mode=padding_mode,
                              align_corners=align_corners, randomize=False)
        return rd

    def inverse(self, data):
        raise NotImplementedError()
```
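Putting the layers together, a lazy pipeline defers all resampling until the pending lists are resolved. A minimal sketch, in which the concrete transform names and the final `apply` step are assumptions about how such a pipeline would be driven:

```python
pipeline = Compose([
    ARotateLikeDictTransform(keys=("image", "label"), lazy_evaluation=True),
    AZoomLikeDictTransform(keys=("image", "label"), lazy_evaluation=True),
])

data = pipeline({
    "image": MetaTensor(torch.rand(1, 64, 64)),
    "label": MetaTensor(torch.randint(0, 2, (1, 64, 64))),
})

# nothing has been resampled yet; each tensor carries two pending operations.
# resolving them composes the matrices and resamples each tensor only once
data = {k: apply(v) for k, v in data.items()}
```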