Skip to content

Commit 10dd6fc

Browse files
ishandutta0098Yu0610
authored andcommitted
📝 [array] Add examples for EnsureType and CastToType (Project-MONAI#7245)
Fixes Project-MONAI#7101 ### Description Added examples in the docstrings for `EnsureType` and `CastToType` transforms which show how they function under different circumstances. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Ishan Dutta <ishandutta0098@gmail.com> Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
1 parent c2b2fd6 commit 10dd6fc

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

monai/transforms/utility/array.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,23 @@ class CastToType(Transform):
333333
"""
334334
Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to
335335
specified PyTorch data type.
336+
337+
Example:
338+
>>> import numpy as np
339+
>>> import torch
340+
>>> transform = CastToType(dtype=np.float32)
341+
342+
>>> # Example with a numpy array
343+
>>> img_np = np.array([0, 127, 255], dtype=np.uint8)
344+
>>> img_np_casted = transform(img_np)
345+
>>> img_np_casted
346+
array([ 0. , 127. , 255. ], dtype=float32)
347+
348+
>>> # Example with a PyTorch tensor
349+
>>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8)
350+
>>> img_tensor_casted = transform(img_tensor)
351+
>>> img_tensor_casted
352+
tensor([ 0., 127., 255.]) # dtype is float32
336353
"""
337354

338355
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
@@ -413,10 +430,26 @@ class EnsureType(Transform):
413430
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
414431
device: for Tensor data type, specify the target device.
415432
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
416-
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
417433
track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,
418434
if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.
419435
436+
Example with wrap_sequence=True:
437+
>>> import numpy as np
438+
>>> import torch
439+
>>> transform = EnsureType(data_type="tensor", wrap_sequence=True)
440+
>>> # Converting a list to a tensor
441+
>>> data_list = [1, 2., 3]
442+
>>> tensor_data = transform(data_list)
443+
>>> tensor_data
444+
tensor([1., 2., 3.]) # All elements have dtype float32
445+
446+
Example with wrap_sequence=False:
447+
>>> transform = EnsureType(data_type="tensor", wrap_sequence=False)
448+
>>> # Converting each element in a list to individual tensors
449+
>>> data_list = [1, 2, 3]
450+
>>> tensors_list = transform(data_list)
451+
>>> tensors_list
452+
[tensor(1), tensor(2.), tensor(3)] # Only second element is float32 rest are int64
420453
"""
421454

422455
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

0 commit comments

Comments
 (0)