Skip to content

5127 fixes dtypes in randomized transforms #5129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def get_hyperparameters(self):
raise NotImplementedError

@abstractmethod
def update_params(self, *args, **kwargs): # type: ignore
def update_params(self, *args, **kwargs):

"""Update Algo parameters according to the hyperparameters to be evaluated."""
raise NotImplementedError

Expand All @@ -48,7 +49,8 @@ def set_score(self):
raise NotImplementedError

@abstractmethod
def run_algo(self, *args, **kwargs): # type: ignore
def run_algo(self, *args, **kwargs):

"""Interface for launching the training given the fetched hyperparameters."""
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion monai/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def concat_val_to_np(
elif ragged:
return np.concatenate(np_list, **kwargs) # type: ignore
else:
return np.concatenate([np_list], **kwargs) # type: ignore
return np.concatenate([np_list], **kwargs)


def concat_multikeys_to_dict(
Expand Down
11 changes: 4 additions & 7 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,13 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if randomize:
super().randomize(None)

if not self._do_transform:
return img

img, *_ = convert_data_type(img, dtype=self.dtype)
if self.channel_wise:
_mean = ensure_tuple_rep(self.mean, len(img))
_std = ensure_tuple_rep(self.std, len(img))
Expand Down Expand Up @@ -335,9 +334,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if self.dtype is not None:
img, *_ = convert_data_type(img, dtype=self.dtype)
img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if self.channel_wise:
for i, d in enumerate(img):
img[i] = self._stdshift(d) # type: ignore
Expand Down Expand Up @@ -394,7 +391,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if randomize:
self.randomize()

Expand Down Expand Up @@ -506,7 +503,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
self.randomize()

if not self._do_transform:
return img
return convert_data_type(img, dtype=self.dtype)[0]

return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)

Expand Down
36 changes: 18 additions & 18 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def __init__(
When `mode` is an integer, using numpy/cupy backends, this argument accepts
{'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
"""
self.mode = mode
self.padding_mode = padding_mode
Expand All @@ -160,7 +160,7 @@ def _post_process(
"""
Small fn to simplify returning data. If `MetaTensor`, update affine. Elif
tracking metadata is desired, create `MetaTensor` with affine. Else, return
image as `torch.Tensor`. Output type is always `torch.float32`.
image as `torch.Tensor`. Output type is always `float32`.

Also append the transform to the stack.
"""
Expand Down Expand Up @@ -473,9 +473,9 @@ def __init__(
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
default False. The option is ignored if output spatial size is specified when calling this transform.
See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
Expand Down Expand Up @@ -526,7 +526,7 @@ def __call__(
Defaults to ``None``, effectively using the value of `self.align_corners`.
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
The option is ignored if output spatial size is specified when calling this transform.
See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
Expand Down Expand Up @@ -961,9 +961,9 @@ class Rotate(InvertibleTransform):
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``np.float32``.
dtype: data type for resampling computation. Defaults to ``float32``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
"""

backend = [TransformBackends.TORCH]
Expand All @@ -975,7 +975,7 @@ def __init__(
mode: str = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: Union[DtypeLike, torch.dtype] = np.float32,
dtype: Union[DtypeLike, torch.dtype] = torch.float32,
) -> None:
self.angle = angle
self.keep_size = keep_size
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def __call__(
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.

Raises:
ValueError: When ``img`` spatially is not one of [2D, 3D].
Expand Down Expand Up @@ -1388,9 +1388,9 @@ class RandRotate(RandomizableTransform, InvertibleTransform):
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``np.float32``.
dtype: data type for resampling computation. Defaults to ``float32``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
"""

backend = Rotate.backend
Expand Down Expand Up @@ -1460,7 +1460,7 @@ def __call__(
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
randomize: whether to execute `randomize()` function first, default to True.
"""
if randomize:
Expand All @@ -1477,7 +1477,7 @@ def __call__(
)
out = rotator(img)
else:
out = convert_to_tensor(img, track_meta=get_track_meta())
out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
rot_info = self.pop_transform(out, check=False) if self._do_transform else {}
self.push_transform(out, extra_info=rot_info)
Expand Down Expand Up @@ -1688,7 +1688,7 @@ def __call__(
self.randomize(img=img)

if not self._do_transform:
out = convert_to_tensor(img, track_meta=get_track_meta())
out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)
else:
out = Zoom(
self._zoom,
Expand Down Expand Up @@ -1731,7 +1731,7 @@ class AffineGrid(Transform):
pixel/voxel relative to the center of the input image. Defaults to no translation.
scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,
a tuple of 3 floats for 3D. Defaults to `1.0`.
dtype: data type for the grid computation. Defaults to ``np.float32``.
dtype: data type for the grid computation. Defaults to ``float32``.
If ``None``, use the data type of input data (if `grid` is provided).
device: device on which the tensor will be allocated, if a new grid is generated.
affine: If applied, ignore the params (`rotate_params`, etc.) and use the
Expand Down Expand Up @@ -2007,7 +2007,7 @@ def __init__(
`[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying
resampling API.
device: device on which the tensor will be allocated.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always ``float32``.

Expand Down Expand Up @@ -2199,7 +2199,7 @@ def __init__(
If `normalized=False`, additional coordinate normalization will be applied before resampling.
See also: :py:func:`monai.networks.utils.normalize_transform`.
device: device on which the tensor will be allocated.
dtype: data type for resampling computation. Defaults to ``np.float32``.
dtype: data type for resampling computation. Defaults to ``float32``.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always ``float32``.
image_only: if True return only the image volume, otherwise return (image, affine).
Expand Down
28 changes: 15 additions & 13 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def __init__(
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.
dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary.
allow_missing_keys: don't raise exception if key is missing.
Expand Down Expand Up @@ -268,9 +268,9 @@ def __init__(
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
"""
Expand Down Expand Up @@ -375,9 +375,9 @@ def __init__(
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.
scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
default False. The option is ignored if output spatial size is specified when calling this transform.
Expand Down Expand Up @@ -696,7 +696,7 @@ def __init__(
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
It also can be a sequence, each element corresponds to a key in ``keys``.
device: device on which the tensor will be allocated.
dtype: data type for resampling computation. Defaults to ``np.float32``.
dtype: data type for resampling computation. Defaults to ``float32``.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always ``float32``.
allow_missing_keys: don't raise exception if key is missing.
Expand Down Expand Up @@ -861,6 +861,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
# do the transform
if do_resampling:
d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid) # type: ignore
else:
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
xform = self.pop_transform(d[key], check=False) if do_resampling else {}
self.push_transform(d[key], extra_info={"do_resampling": do_resampling, "rand_affine_info": xform})
Expand Down Expand Up @@ -1320,9 +1322,9 @@ class Rotated(MapTransform, InvertibleTransform):
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
dtype: data type for resampling computation. Defaults to ``np.float32``.
dtype: data type for resampling computation. Defaults to ``float32``.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
"""
Expand Down Expand Up @@ -1393,9 +1395,9 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform):
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
the output data type is always ``float32``.
It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
"""
Expand Down Expand Up @@ -1450,7 +1452,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
randomize=False,
)
else:
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
rot_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
self.push_transform(d[key], extra_info=rot_info)
Expand Down Expand Up @@ -1618,7 +1620,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False
)
else:
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
self.push_transform(d[key], extra_info=xform)
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_dtype(data: Any):

def convert_to_tensor(
data,
dtype: Optional[torch.dtype] = None,
dtype: Union[DtypeLike, torch.dtype] = None,
device: Union[None, str, torch.device] = None,
wrap_sequence: bool = False,
track_meta: bool = False,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_rand_rician_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def test_correct_results(self, _, in_type, mean, std):
rician_fn.set_random_state(seed)
im = in_type(self.imt)
noised = rician_fn(im)
if isinstance(im, torch.Tensor):
self.assertEqual(im.dtype, noised.dtype)
np.random.seed(seed)
np.random.random()
_std = np.random.uniform(0, std)
Expand Down
Loading