
Developer Guide Lazy Transforms

Ben Murray edited this page Nov 23, 2022 · 8 revisions

Lazy Resampling Transform Redesign

Transforms are refactored into multiple layers:

  1. functional
  2. array-based
     a. deterministic
     b. random
  3. dictionary-based
     a. deterministic
     b. random

Functional transforms

Functional transforms are the base of every transform implementation. They are stateless implementations of the actual transform operation. Each can operate either in immediate mode (the operation is defined and then applied straight away) or in lazy mode (the operation is appended to the MetaTensor's list of pending operations).

Functional transforms have the following pattern:

def functional_operation(
    img: torch.Tensor,
    ..., # operation specific parameters
    shape_override: Optional[Sequence[int]] = None,
    lazy_evaluation: Optional[bool] = True
):
    img_ = convert_to_tensor(img, track_meta=get_track_meta())

    # the effective shape of the image can differ from its actual current shape
    # when the image has one or more pending transforms to be applied. Transforms
    # typically need the shape that the image will have at the point this transform
    # is carried out rather than the shape of the image at the point this transform
    # is defined
    input_shape = img_.shape if shape_override is None else shape_override

    # this is typically needed to fully specify the transform
    input_ndim = len(input_shape) - 1

    transform = get_a_specific_homogeneous_matrix_or_grid_describing_the_operation(...)

    # this might be needed if the transform is known to change the shape of the
    # resulting image
    im_extents = extents_from_shape(input_shape)
    im_extents = [transform @ e for e in im_extents]
    output_shape = shape_from_extents(input_shape, im_extents)

    # everything required to specify the transform at the point that it is applied
    # note that shape_override should always be set as this is how chains of lazy
    # transforms pass the correct image shape on to the next transform
    metadata = {
        ...,
        "shape_override": output_shape
    }

    # either apply the operation immediately or just append it to the pending list
    return lazily_apply_op(img_, MetaMatrix(transform, metadata), lazy_evaluation)
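The shape bookkeeping in this pattern relies on extents_from_shape and shape_from_extents, which are not defined on this page. The following is a minimal pure-Python sketch of what they might do, assuming the corners are taken about the image centre and the output size is the rounded axis-aligned bounding box of the transformed corners (both assumptions, as is the hypothetical rotation_2d helper); matvec stands in for `transform @ e` on tensors.

```python
import math
from itertools import product
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def extents_from_shape(shape: Sequence[int]) -> List[List[float]]:
    # corners of the spatial extent, centred on the origin, in homogeneous
    # coordinates; shape is channel-first: (C, *spatial)
    halves = [s / 2 for s in shape[1:]]
    return [list(corner) + [1.0] for corner in product(*[(-h, h) for h in halves])]


def matvec(m: Matrix, v: List[float]) -> List[float]:
    # stands in for ``transform @ e`` on tensors
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def shape_from_extents(shape: Sequence[int], extents: List[List[float]]) -> Tuple[int, ...]:
    # size of the axis-aligned bounding box around the transformed corners
    sizes = []
    for d in range(len(shape) - 1):
        coords = [e[d] for e in extents]
        sizes.append(round(max(coords) - min(coords)))
    return (shape[0], *sizes)


def rotation_2d(angle: float) -> Matrix:
    # homogeneous 2D rotation about the origin
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


input_shape = (1, 4, 6)
transform = rotation_2d(math.pi / 2)
im_extents = [matvec(transform, e) for e in extents_from_shape(input_shape)]
print(shape_from_extents(input_shape, im_extents))  # (1, 6, 4)
```

Rotating a (1, 4, 6) image by 90 degrees swaps the spatial sizes, and that swapped shape is what has to be carried forward as shape_override.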

lazily_apply_op is defined as follows:

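def lazily_apply_op(
    tensor, op, lazy_evaluation
) -> Union[MetaTensor, Tuple[torch.Tensor, Optional[MetaMatrix]]]:
    if isinstance(tensor, MetaTensor):
        # MetaTensors carry their own pending list: record the op there and
        # apply the whole list only if immediate mode was requested
        tensor.push_pending_operation(op)
        if lazy_evaluation is False:
            return apply(tensor)
        return tensor
    # plain tensors have nowhere to store the op, so it is either applied
    # straight away (immediate mode) or returned alongside the tensor (lazy mode)
    if lazy_evaluation is False:
        return apply(tensor, [op]), None
    return tensor, op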
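To show why recording operations pays off, here is a toy, dependency-free model of the MetaTensor branch of lazily_apply_op. ToyMetaTensor, matmul and this apply are hypothetical stand-ins: the real apply resamples image data, whereas this one only composes the pending matrices, which is the step that lets several lazy transforms be carried out with a single resample.

```python
from typing import List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # plain-Python matrix product, standing in for ``@`` on tensors
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


class ToyMetaTensor:
    """Hypothetical stand-in for MetaTensor: an affine plus a pending-operation list."""

    def __init__(self, affine: Matrix):
        self.affine = affine              # transform already applied to the data
        self.pending: List[Matrix] = []   # operations recorded but not yet applied

    def push_pending_operation(self, op: Matrix) -> None:
        self.pending.append(op)


def apply(tensor: ToyMetaTensor) -> ToyMetaTensor:
    # compose every pending matrix into one; a real implementation would then
    # resample the image data once using the combined transform
    combined = tensor.affine
    for op in tensor.pending:
        combined = matmul(op, combined)  # later operations are applied on the left
    tensor.affine, tensor.pending = combined, []
    return tensor


def lazily_apply_op(tensor: ToyMetaTensor, op: Matrix, lazy_evaluation: Optional[bool]) -> ToyMetaTensor:
    # mirrors the MetaTensor branch: record the op, apply only in immediate mode
    tensor.push_pending_operation(op)
    return apply(tensor) if lazy_evaluation is False else tensor


identity = [[1.0, 0.0], [0.0, 1.0]]
scale2 = [[2.0, 0.0], [0.0, 1.0]]  # doubles the first axis

t = ToyMetaTensor(identity)
lazily_apply_op(t, scale2, lazy_evaluation=True)
lazily_apply_op(t, scale2, lazy_evaluation=True)
print(len(t.pending))  # 2: nothing has been applied yet
apply(t)
print(t.affine)  # [[4.0, 0.0], [0.0, 1.0]]
```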

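As a rule, functional transform metadata should include:

  1. the parameters specific to the transform (e.g. angles for rotation)
  2. parameters influencing how the operation is carried out, such as mode, padding_mode, etc.
  3. shape_override, the shape of the resulting image. This matters most when it differs from the input shape, but as noted in the pattern above it should always be set, because the next lazy transform reads it from the most recent pending operation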
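As an illustration of the three kinds of metadata, a hypothetical zoom_metadata helper for a resize-by-factor transform might build its dictionary as follows. The key names other than shape_override, and the default mode and padding_mode values, are assumptions made for this sketch; shape_override is set unconditionally, in line with the note in the functional pattern that it should always be set.

```python
from typing import Any, Dict, Sequence, Tuple


def zoom_metadata(
    input_shape: Sequence[int],
    factor: float,
    mode: str = "bilinear",
    padding_mode: str = "border",
) -> Dict[str, Any]:
    # 1. parameters specific to the transform
    metadata: Dict[str, Any] = {"factor": factor}
    # 2. parameters influencing how the operation is carried out
    metadata["mode"] = mode
    metadata["padding_mode"] = padding_mode
    # 3. the resulting shape, recorded even when it equals the input shape
    output_shape: Tuple[int, ...] = (
        input_shape[0], *(round(s * factor) for s in input_shape[1:])
    )
    metadata["shape_override"] = output_shape
    return metadata


print(zoom_metadata((1, 10, 20), 0.5))
# {'factor': 0.5, 'mode': 'bilinear', 'padding_mode': 'border', 'shape_override': (1, 5, 10)}
```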

Alternative design using functors to defer transform calculation

Note: This section is a discussion of a design option. It is not currently the plan to implement this.

Instead of passing the overridden shape from one functional transform to the next, the functional transforms could be written as functors: each returns a callable that computes the transform only when it is actually applied, at which point the shape of the image at that step in the chain is known. A functor transform would look like this:

def functional_operation_functor(
    img: torch.Tensor,
    ..., # operation specific parameters
    lazy_evaluation: Optional[bool] = True
):
    def _inner(inner_img, shape_override: Optional[Sequence[int]] = None):
        img_ = convert_to_tensor(inner_img, track_meta=get_track_meta())

        # by the time _inner runs, the shape the image will have at this point in
        # the chain is known, and can be supplied by whatever evaluates the
        # pending list
        input_shape = img_.shape if shape_override is None else shape_override

        # this is typically needed to fully specify the transform
        input_ndim = len(input_shape) - 1

        transform = get_a_specific_homogeneous_matrix_or_grid_describing_the_operation(...)

        # this might be needed if the transform is known to change the shape of the
        # resulting image
        im_extents = extents_from_shape(input_shape)
        im_extents = [transform @ e for e in im_extents]
        output_shape = shape_from_extents(input_shape, im_extents)

        # everything required to specify the transform at the point that it is applied
        metadata = {
            ...,
            "shape_override": output_shape
        }

        return MetaMatrix(transform, metadata)

    # the deferred _inner is queued (or applied) instead of a precomputed MetaMatrix
    return lazily_apply_op(img, _inner, lazy_evaluation)
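The property that makes the functor design attractive is that each deferred operation receives the shape at evaluation time, so nothing has to be threaded through when the transforms are defined. The following is a shape-only sketch using hypothetical pad and transpose_2d functors and a hypothetical evaluate loop; a real functor would also build the transform matrix and metadata.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]
# a deferred operation receives the shape the image has when it is applied
# and returns the shape it produces
DeferredOp = Callable[[Shape], Shape]


def pad(amount: int) -> DeferredOp:
    def _inner(shape: Shape) -> Shape:
        return (shape[0], *(s + 2 * amount for s in shape[1:]))
    return _inner


def transpose_2d() -> DeferredOp:
    def _inner(shape: Shape) -> Shape:
        c, h, w = shape
        return (c, w, h)
    return _inner


def evaluate(shape: Shape, pending: List[DeferredOp]) -> Shape:
    for op in pending:
        shape = op(shape)  # each op sees the shape produced by the one before
    return shape


pending = [pad(1), transpose_2d(), pad(2)]  # defined without knowing any shapes
print(evaluate((1, 4, 6), pending))  # (1, 12, 10)
```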

Array transforms

Deterministic array transforms

class ADeterministicArrayTransform(InvertibleTransform, LazyTransform):

    backend = [TransformBackends.TORCH]

    def __init__(
            self,
            ..., # transform-specific arguments
            lazy_evaluation: Optional[bool] = False
    ):
        LazyTransform.__init__(self, lazy_evaluation)
        # set member variables for transform-specific arguments
        ...

    def __call__(
            self,
            img: NdarrayOrTensor,
            ..., # call-time transform-specific arguments
            shape_override: Optional[Sequence] = None
    ) -> NdarrayOrTensor:

        # determine transform-specific parameters to pass to function
        ...

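        # if no override was passed in, use the output shape recorded by the most
        # recent pending transform, as that is the shape the image will have when
        # this transform is carried out
        shape_override_ = shape_override
        if (shape_override_ is None and isinstance(img, MetaTensor) and
                img.has_pending_transforms):
            tx = img.peek_pending_transform()
            shape_override_ = tx.metadata.get("shape_override", None)

        # the functional returns a MetaTensor for MetaTensor input and a
        # (tensor, pending op) tuple otherwise
        result = functional_operation(img, ..., shape_override=shape_override_,
                                      lazy_evaluation=self.lazy_evaluation)
        img_t = result[0] if isinstance(result, tuple) else result

        return img_t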

    def inverse(self, data):
        raise NotImplementedError()
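The shape lookup in __call__ can be exercised in isolation. In this sketch ToyMetaTensor and ToyPad are hypothetical, pending operations are stored as plain metadata dicts rather than objects with a .metadata attribute, and padding stands in for any shape-changing operation.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple


class ToyMetaTensor:
    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape                       # shape of the data as it is now
        self.pending: List[Dict[str, Any]] = []  # metadata of pending operations

    @property
    def has_pending_transforms(self) -> bool:
        return len(self.pending) > 0

    def peek_pending_transform(self) -> Dict[str, Any]:
        return self.pending[-1]


class ToyPad:
    """Lazy-only array transform: records a pad as a pending operation."""

    def __init__(self, amount: int):
        self.amount = amount

    def __call__(self, img: ToyMetaTensor, shape_override: Optional[Sequence[int]] = None) -> ToyMetaTensor:
        shape = shape_override
        if shape is None and img.has_pending_transforms:
            shape = img.peek_pending_transform().get("shape_override")
        if shape is None:  # no override and nothing pending: use the real shape
            shape = img.shape
        output_shape = (shape[0], *(s + 2 * self.amount for s in shape[1:]))
        img.pending.append({"amount": self.amount, "shape_override": output_shape})
        return img


img = ToyMetaTensor((1, 4, 6))
ToyPad(1)(img)
ToyPad(2)(img)
print(img.shape, img.peek_pending_transform()["shape_override"])  # (1, 4, 6) (1, 10, 12)
```

Two chained lazy pads leave the data shape untouched, while the most recent pending operation already reports the shape the second pad will produce.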

Random array transforms

Random array transforms wrap the deterministic version of the transform:

class ARandomArrayTransform(RandomizableTransform, InvertibleTransform, LazyTrait):
    def __init__(
            self,
            ..., # transform-specific args
            lazy_evaluation: Optional[bool] = True
    ):
        # prob is one of the transform-specific args above
        RandomizableTransform.__init__(self, prob)

        self.op = ADeterministicArrayTransform(..., lazy_evaluation=lazy_evaluation)

        self.random_params = 0 # some default value

    def randomize(self, data: Optional[Any] = None) -> None:
        super().randomize(None)
        if self._do_transform:
            self.random_params = self.R.some_random_parameterized_value()
        else:
            self.random_params = 0 # the default value again

    def __call__(
            self,
            img: NdarrayOrTensor,
            ..., # call-time transform-specific args
            randomize: Optional[bool] = True,
            shape_override: Optional[Sequence] = None
    ) -> NdarrayOrTensor:
        if randomize:
            self.randomize(data=img)

        return self.op(img, self.random_params, ..., shape_override=shape_override)

    @property
    def lazy_evaluation(self):
        return self.op.lazy_evaluation

    @lazy_evaluation.setter
    def lazy_evaluation(self, value):
        self.op.lazy_evaluation = value

    def inverse(
            self,
            data: NdarrayOrTensor,
    ):
        raise NotImplementedError()
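The split between randomize() and __call__(..., randomize=...) is what allows one set of random parameters to be reused across calls. Below is a minimal stdlib sketch with a hypothetical ToyShift/ToyRandShift pair; random.Random stands in for the transform's self.R, and a scalar offset stands in for random_params.

```python
import random
from typing import Any, List, Optional


class ToyShift:
    """Deterministic op: adds a fixed offset to every element."""

    def __call__(self, img: List[float], offset: float) -> List[float]:
        return [v + offset for v in img]


class ToyRandShift:
    """Random op that wraps ToyShift, mirroring the randomize/__call__ split."""

    def __init__(self, max_offset: float, prob: float = 1.0, seed: Optional[int] = None):
        self.R = random.Random(seed)
        self.max_offset = max_offset
        self.prob = prob
        self.op = ToyShift()
        self.random_params = 0.0  # default: no shift

    def randomize(self, data: Any = None) -> None:
        if self.R.random() < self.prob:
            self.random_params = self.R.uniform(-self.max_offset, self.max_offset)
        else:
            self.random_params = 0.0  # the default value again

    def __call__(self, img: List[float], randomize: bool = True) -> List[float]:
        if randomize:
            self.randomize(data=img)
        return self.op(img, self.random_params)


t = ToyRandShift(5.0, seed=0)
a = t([0.0, 1.0])               # draws a new offset
b = t([10.0], randomize=False)  # reuses the same offset
```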

Alternative design using external randomizer

class ARandomArrayTransform(InvertibleTransform, LazyTrait, RandomizableTrait):

    def __init__(
            self,
            ..., # transform-specific args
            lazy_evaluation: Optional[bool] = True
    ):
        self.randomizer = ARandomizer(..., prob)

        self.op = ADeterministicArrayTransform(0, ..., lazy_evaluation=lazy_evaluation)

    def __call__(
            self,
            img: NdarrayOrTensor,
            ..., # call-time transform-specific args
            shape_override: Optional[Sequence] = None
    ) -> NdarrayOrTensor:
        angles = self.randomizer.sample(img)

        # TODO: the random transforms are implemented on top of the array ops, so
        # the operation recorded for "RandRotate" is named "Rotate". If it needs to
        # be recorded as "RandRotate" instead, there are several options:
        # 1. use the functional op directly
        # 2. pass a name override to the array op
        # mode, padding_mode and align_corners are transform-specific member
        # variables set in __init__
        return self.op(img, angles, self.mode, self.padding_mode, self.align_corners,
                       shape_override=shape_override)

    @property
    def lazy_evaluation(self):
        return self.op.lazy_evaluation

    @lazy_evaluation.setter
    def lazy_evaluation(self, value):
        self.op.lazy_evaluation = value

    def inverse(
            self,
            data: NdarrayOrTensor,
    ):
        raise NotImplementedError()
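A sketch of the external-randomizer variant, which also shows option 2 from the TODO (a name override passed to the array op). ToyRotationRandomizer, ToyRotate and ToyRandRotate are hypothetical, and the "image" is just a dict that records which operations were applied.

```python
import random
from typing import Any, Dict, Optional, Sequence, Tuple


class ToyRotationRandomizer:
    """Samples one angle per axis, or zeros when the transform is skipped."""

    def __init__(self, ranges: Sequence[Tuple[float, float]], prob: float, seed: Optional[int] = None):
        self.ranges = ranges
        self.prob = prob
        self.R = random.Random(seed)

    def sample(self, img: Any = None) -> Tuple[float, ...]:
        if self.R.random() >= self.prob:
            return tuple(0.0 for _ in self.ranges)
        return tuple(self.R.uniform(lo, hi) for lo, hi in self.ranges)


class ToyRotate:
    """Deterministic op: records (name, angles) instead of resampling."""

    def __call__(self, img: Dict[str, Any], angles: Tuple[float, ...], name: str = "Rotate") -> Dict[str, Any]:
        out = dict(img)
        out["applied"] = list(img.get("applied", [])) + [(name, tuple(angles))]
        return out


class ToyRandRotate:
    def __init__(self, ranges: Sequence[Tuple[float, float]], prob: float, seed: Optional[int] = None):
        self.randomizer = ToyRotationRandomizer(ranges, prob, seed)
        self.op = ToyRotate()

    def __call__(self, img: Dict[str, Any]) -> Dict[str, Any]:
        angles = self.randomizer.sample(img)
        # option 2: override the operation name so it is recorded as "RandRotate"
        return self.op(img, angles, name="RandRotate")


out = ToyRandRotate([(-0.1, 0.1)], prob=1.0, seed=0)({})
print(out["applied"][0][0])  # RandRotate
```

Keeping the randomizer separate means the transform class no longer needs to inherit randomization machinery; it only asks for a sample and forwards it.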

Dictionary transforms

Deterministic dictionary transforms

class ADeterministicDictionaryTransform(MapTransform, InvertibleTransform, LazyTrait):

    backend = ADeterministicArrayTransform.backend

    def __init__(
        self,
        keys: KeysCollection,
        ..., # transform-specific arguments
        lazy_evaluation: Optional[bool] = True,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

        # operation-specific member variables
        ...

        self.op = ADeterministicArrayTransform(..., lazy_evaluation=lazy_evaluation)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        rd = dict(data)
        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
            rd, self.mode, self.padding_mode, self.align_corners, self.dtype
        ):
            rd[key] = self.op(rd[key], mode=mode, padding_mode=padding_mode,
                              align_corners=align_corners, dtype=dtype)
        return rd

    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = self.op.inverse(d[key])
        return d
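The per-key loop is the whole job of the dictionary layer. Here is a stdlib sketch of that behaviour with a hypothetical ToyMapTransform whose key_iterator mimics the allow_missing_keys handling; the key_iterator above also zips in per-key arguments such as self.mode, which this sketch omits.

```python
from typing import Any, Callable, Dict, Hashable, Iterable, Iterator, Mapping


class ToyMapTransform:
    """Applies a single-item op to each listed key of a dictionary."""

    def __init__(self, keys: Iterable[Hashable], op: Callable[[Any], Any], allow_missing_keys: bool = False):
        self.keys = tuple(keys)
        self.op = op
        self.allow_missing_keys = allow_missing_keys

    def key_iterator(self, data: Mapping[Hashable, Any]) -> Iterator[Hashable]:
        for key in self.keys:
            if key in data:
                yield key
            elif not self.allow_missing_keys:
                raise KeyError(f"key {key!r} is not in the data")

    def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
        d = dict(data)  # shallow copy: the caller's dictionary is left unchanged
        for key in self.key_iterator(d):
            d[key] = self.op(d[key])
        return d


double = ToyMapTransform(["image", "label"], lambda v: [x * 2 for x in v])
print(double({"image": [1, 2], "label": [3], "meta": "x"}))
# {'image': [2, 4], 'label': [6], 'meta': 'x'}
```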

Random dictionary transforms

class ARandomDictionaryTransform(MapTransform, InvertibleTransform,
                                 LazyTrait, RandomizableTrait):

    def __init__(
            self,
            keys: KeysCollection,
            ..., # transform-specific args
            allow_missing_keys: bool = False,
            lazy_evaluation: Optional[bool] = True,
    ):
        MapTransform.__init__(self, keys, allow_missing_keys)
        # for rotation this is RandRotate2(range_x, range_y, range_z, prob, keep_size,
        # mode, padding_mode, align_corners, dtype, lazy_evaluation)
        self.op = ARandomArrayTransform(..., lazy_evaluation=lazy_evaluation)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]):
        rd = dict(data)
        first_key = self.first_key(rd)
        if first_key == ():
            out = convert_to_tensor(rd, track_meta=get_track_meta())
            return out

        # randomize once, so that every key is transformed with the same parameters
        self.op.randomize(rd[first_key])

        it = self.key_iterator(rd, self.mode, self.padding_mode, self.align_corners)
        for key, mode, padding_mode, align_corners in it:
            rd[key] = self.op(rd[key], mode=mode, padding_mode=padding_mode,
                              align_corners=align_corners, randomize=False)

        return rd


    def inverse(self, data):
        raise NotImplementedError()
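The important detail in __call__ is that randomize is called once, on the first key, and every key is then transformed with randomize=False, so that, for example, an image and its label receive identical parameters. A stdlib sketch with hypothetical ToyRandShift/ToyRandShiftd classes:

```python
import random
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence


class ToyRandShift:
    def __init__(self, max_offset: float, seed: Optional[int] = None):
        self.R = random.Random(seed)
        self.max_offset = max_offset
        self.offset = 0.0

    def randomize(self, data: Any = None) -> None:
        self.offset = self.R.uniform(-self.max_offset, self.max_offset)

    def __call__(self, img: List[float], randomize: bool = True) -> List[float]:
        if randomize:
            self.randomize(img)
        return [v + self.offset for v in img]


class ToyRandShiftd:
    def __init__(self, keys: Sequence[Hashable], max_offset: float, seed: Optional[int] = None):
        self.keys = tuple(keys)
        self.op = ToyRandShift(max_offset, seed)

    def __call__(self, data: Mapping[Hashable, List[float]]) -> Dict[Hashable, List[float]]:
        d = dict(data)
        present = [k for k in self.keys if k in d]
        if not present:  # counterpart of the first_key == () early return
            return d
        # draw the random parameters once, from the first key...
        self.op.randomize(d[present[0]])
        # ...then apply the same parameters to every key
        for key in present:
            d[key] = self.op(d[key], randomize=False)
        return d


t = ToyRandShiftd(["image", "label"], 5.0, seed=1)
out = t({"image": [0.0], "label": [0.0]})
print(out["image"] == out["label"])  # True
```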