Description
The standard rule for dtype support for images and videos is:
- All floating point and integer tensors are supported.
- Floating point tensors are valid in the range `[0.0, 1.0]` and integer tensors in `[0, torch.iinfo(dtype).max]`. (This is currently under review, since there were a few cases where this was not true or simply not handled. See "Don't hardcode 255 unless uint8 is enforced" #6825.)
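The range rule above can be expressed as a small check. This is a minimal sketch; `max_value` and `is_in_valid_range` are hypothetical helper names for illustration, not part of any torchvision API.

```python
from typing import Union

import torch


def max_value(dtype: torch.dtype) -> Union[int, float]:
    """Hypothetical helper: upper bound of the valid value range for an image dtype."""
    if dtype.is_floating_point:
        # Floating point images are expected in [0.0, 1.0].
        return 1.0
    # Integer images are expected in [0, torch.iinfo(dtype).max].
    return torch.iinfo(dtype).max


def is_in_valid_range(image: torch.Tensor) -> bool:
    """Hypothetical helper: True if every value lies in [0, max_value(image.dtype)]."""
    return bool(((image >= 0) & (image <= max_value(image.dtype))).all())
```

For example, `max_value(torch.uint8)` is `255` and `max_value(torch.int16)` is `32767`, so a hardcoded `255` is only correct when the input is known to be uint8.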
However, we currently have two kernels that only support uint8 images or videos:
`vision/torchvision/prototype/transforms/functional/_color.py`, lines 373 to 375 in c84dbfa:

```python
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
```

`vision/torchvision/transforms/functional_tensor.py`, lines 788 to 789 in c84dbfa:

```python
if img.dtype != torch.uint8:
    raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
```
This also holds for transforms v1, so this is not a problem specific to the new API.
One consequence is that the AutoAugment (AA) transforms only support uint8 images:

`vision/torchvision/transforms/autoaugment.py`, lines 104 to 107 in c84dbfa:

```python
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
```

since both

`vision/torchvision/transforms/autoaugment.py`, lines 76 to 77 in c84dbfa:

```python
elif op_name == "Posterize":
    img = F.posterize(img, int(magnitude))
```

and

`vision/torchvision/transforms/autoaugment.py`, lines 82 to 83 in c84dbfa:

```python
elif op_name == "Equalize":
    img = F.equalize(img)
```

are used.
One possible way to mitigate this is to simply call `convert_dtype(image, torch.uint8)` at the beginning and convert back after the computation.
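A minimal sketch of this convert, compute, convert-back approach for float inputs. `apply_uint8_kernel` is a hypothetical name, and the scale-and-round conversion is an assumption that may not match torchvision's `convert_dtype` exactly. Note that the round trip quantizes float inputs to 256 levels.

```python
from typing import Callable

import torch


def apply_uint8_kernel(
    image: torch.Tensor, kernel: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    """Hypothetical wrapper: run a uint8-only kernel on a float image in [0.0, 1.0]."""
    if image.dtype == torch.uint8:
        # Already in the dtype the kernel expects.
        return kernel(image)
    if not image.is_floating_point():
        raise TypeError(f"Expected a uint8 or floating point image, but found {image.dtype}")
    # Float [0.0, 1.0] -> uint8 [0, 255]; scale-and-round is an assumed conversion.
    image_uint8 = image.mul(255).round().clamp(0, 255).to(torch.uint8)
    output = kernel(image_uint8)
    # uint8 [0, 255] -> original floating point dtype in [0.0, 1.0].
    return output.to(image.dtype).div(255)
```

For example, `apply_uint8_kernel(torch.tensor([0.0, 1.0]), lambda x: 255 - x)` returns `tensor([1., 0.])`. The cost of this design is an extra allocation and a loss of precision for float inputs, which is why a native float path is preferable where one is easy to write.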
That is probably needed for equalize, since we recently switched away from torch's histogram ops towards our "custom" implementation to enable batch processing (#6757). However, that implementation relies on the input being an integer tensor, and in its current form even a uint8 tensor, due to some hardcoded constants.
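To illustrate why a batched, histogram-based kernel tends to be tied to uint8, here is a minimal sketch (not the implementation from #6757) that computes per-image histograms with `scatter_add_`. The fixed 256 bins, one per possible uint8 value, are exactly the kind of hardcoded constant that breaks for other integer dtypes or for floats. `batched_histogram_uint8` is a hypothetical name.

```python
import torch


def batched_histogram_uint8(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: histogram of each (H, W) plane of a (..., H, W) uint8 tensor."""
    if images.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {images.dtype}")
    # Hardcoded: one bin per possible uint8 value. Only valid for uint8 inputs.
    num_bins = 256
    # Flatten every plane into a row so all histograms are computed in one call.
    flat = images.reshape(-1, images.shape[-2] * images.shape[-1]).long()
    hist = torch.zeros(flat.shape[0], num_bins, dtype=torch.int64, device=images.device)
    # Each pixel value is used directly as its bin index.
    hist.scatter_add_(1, flat, torch.ones_like(flat))
    return hist.reshape(*images.shape[:-2], num_bins)
```

Because pixel values are used directly as bin indices, a float input (or an int16 input with values above 255) has no valid mapping into these bins without a conversion first.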
For posterize, I think it is fairly easy to support float inputs directly, without going through a dtype conversion first.
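A minimal sketch of a posterize kernel that handles float inputs directly. The float branch quantizes `[0.0, 1.0]` into `2 ** bits` equally sized buckets and maps each value to the lower edge of its bucket; this is an assumed formulation that approximates, but does not exactly reproduce, the uint8 bit masking on `k / 255` inputs. The function name and the `bits` range check are illustrative.

```python
import torch


def posterize(image: torch.Tensor, bits: int) -> torch.Tensor:
    """Sketch: keep `bits` bits of precision per channel for uint8 or float images."""
    if not 0 <= bits <= 8:
        raise ValueError(f"bits must be in [0, 8], but got {bits}")
    if image.is_floating_point():
        levels = 1 << bits
        # Quantize [0.0, 1.0] into `levels` buckets; the clamp keeps 1.0 in the top
        # bucket instead of creating an extra level.
        return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
    if image.dtype == torch.uint8:
        # Keep the `bits` most significant bits and zero out the rest.
        mask = (0xFF << (8 - bits)) & 0xFF
        return image & mask
    raise TypeError(f"Expected a uint8 or floating point image, but found {image.dtype}")
```

For example, with `bits=1` the float input `[0.0, 0.3, 0.5, 0.99, 1.0]` becomes `[0.0, 0.0, 0.5, 0.5, 0.5]`, while the uint8 input `[0, 127, 128, 255]` becomes `[0, 0, 128, 128]`.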