Skip to content

Commit

Permalink
Add convert_image_dtype to functionals (#2078)
Browse files Browse the repository at this point in the history
* add convert_image_dtype to functionals

* add ConvertImageDtype transform

* add test

* remove underscores from numbers since they are not compatible with python<3.6

* address review comments 1/3

* fix torch.bool

* use torch.iinfo in test

* fix flake8

* remove double conversion

* fix flake9

* bug fix

* add error messages to test

* disable torch.float16 and torch.half for now

* add docstring

* add test for consistency

* move nested function to top

* test in CI

* dirty progress

* add int to int and cleanup

* lint

Co-authored-by: Philip Meier <meier.philip@posteo.de>
  • Loading branch information
pmeier and Philip Meier committed Jun 11, 2020
1 parent 54da5db commit c2e8a00
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 4 deletions.
110 changes: 110 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


def cycle_over(objs):
objs = list(objs)
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]


def int_dtypes():
yield from iter(
(torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
)


def float_dtypes():
yield from iter((torch.float32, torch.float, torch.float64, torch.double))


class Tester(unittest.TestCase):

def test_crop(self):
Expand Down Expand Up @@ -510,6 +526,100 @@ def test_to_tensor(self):
output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)

def test_convert_image_dtype_float_to_int(self):
for input_dtype in float_dtypes():
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
with self.assertRaises(RuntimeError):
transform(input_image)
else:
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_float(self):
for input_dtype in int_dtypes():
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertGreaterEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)
self.assertLessEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_int(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max

# see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
if input_max >= output_max:
error_term = 0
else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max + error_term)

def test_convert_image_dtype_int_to_int_consistency(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max
if output_max <= input_max:
continue

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image))

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_to_tensor(self):
trans = transforms.ToTensor()
Expand Down
59 changes: 59 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,65 @@ def pil_to_tensor(pic):
return img


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
Args:
image (torch.Tensor): Image to be converted
dtype (torch.dtype): Desired data type of the output
Returns:
(torch.Tensor): Converted image
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
if image.dtype == dtype:
return image

if image.dtype.is_floating_point:
# float to float
if dtype.is_floating_point:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)

eps = 1e-3
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
else:
# int to float
if dtype.is_floating_point:
max = torch.iinfo(image.dtype).max
image = image.to(dtype)
return image / max

# int to int
input_max = torch.iinfo(image.dtype).max
output_max = torch.iinfo(dtype).max

if input_max > output_max:
factor = (input_max + 1) // (output_max + 1)
image = image // factor
return image.to(dtype)
else:
factor = (output_max + 1) // (input_max + 1)
image = image.to(dtype)
return image * factor


def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image.
Expand Down
33 changes: 29 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from . import functional as F


__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing"]

_pil_interpolation_to_str = {
Expand Down Expand Up @@ -115,6 +115,31 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class ConvertImageDtype(object):
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""

def __init__(self, dtype: torch.dtype) -> None:
self.dtype = dtype

def __call__(self, image: torch.Tensor) -> torch.Tensor:
return F.convert_image_dtype(image, self.dtype)


class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image.
Expand Down

0 comments on commit c2e8a00

Please sign in to comment.