add int to int and cleanup
pmeier committed Jun 11, 2020
1 parent adfb096 commit 28e2fbf
Showing 3 changed files with 143 additions and 75 deletions.
137 changes: 98 additions & 39 deletions test/test_transforms.py
@@ -23,6 +23,20 @@
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


def cycle_over(objs):
    # Yield each object together with a list of all the others, e.g.
    # cycle_over("abc") -> ("a", ["b", "c"]), ("b", ["a", "c"]), ("c", ["a", "b"])
    objs = list(objs)
    for idx, obj in enumerate(objs):
        yield obj, objs[:idx] + objs[idx + 1:]


def int_dtypes():
    yield from iter(
        (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
    )


def float_dtypes():
    yield from iter((torch.float32, torch.float, torch.float64, torch.double))


class Tester(unittest.TestCase):

    def test_crop(self):
@@ -510,54 +524,99 @@ def test_to_tensor(self):
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    def test_convert_image_dtype(self):
        def cycle_over(objs):
            objs = list(objs)
            for idx, obj in enumerate(objs):
                yield obj, objs[:idx] + objs[idx + 1:]

        # dtype_max_value = {
        #     dtype: 1.0
        #     for dtype in (torch.float32, torch.float, torch.float64, torch.double)#, torch.bool,)
        #     # torch.float16 and torch.half are disabled for now since they do not support torch.max
        #     # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051
        #     # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, )
        # }
        dtype_max_value = {}
        dtype_max_value.update(
            {
                dtype: torch.iinfo(dtype).max
                for dtype in (
                    torch.uint8,
                    torch.int8,
                    torch.int16,
                    torch.short,
                    torch.int32,
                    torch.int,
                    torch.int64,
                    torch.long,
                )
            }
        )
    def test_convert_image_dtype_float_to_float(self):
        for input_dtype, output_dtypes in cycle_over(float_dtypes()):
            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
            for output_dtype in output_dtypes:
                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    output_image = transform(input_image)

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0.0, 1.0

                    self.assertAlmostEqual(actual_min, desired_min)
                    self.assertAlmostEqual(actual_max, desired_max)

    def test_convert_image_dtype_float_to_int(self):
        for input_dtype in float_dtypes():
            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
            for output_dtype in int_dtypes():
                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)

                    if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
                        input_dtype == torch.float64 and output_dtype == torch.int64
                    ):
                        with self.assertRaises(RuntimeError):
                            transform(input_image)
                    else:
                        output_image = transform(input_image)

        for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()):
            input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype]
                        actual_min, actual_max = output_image.tolist()
                        desired_min, desired_max = 0, torch.iinfo(output_dtype).max

                        self.assertEqual(actual_min, desired_min)
                        self.assertEqual(actual_max, desired_max)

    def test_convert_image_dtype_int_to_float(self):
        for input_dtype in int_dtypes():
            input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
            for output_dtype in float_dtypes():
                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    output_image = transform(input_image)

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0.0, 1.0

                    self.assertAlmostEqual(actual_min, desired_min)
                    self.assertGreaterEqual(actual_min, desired_min)
                    self.assertAlmostEqual(actual_max, desired_max)
                    self.assertLessEqual(actual_max, desired_max)

    def test_convert_image_dtype_int_to_int(self):
        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
            input_max = torch.iinfo(input_dtype).max
            input_image = torch.tensor((0, input_max), dtype=input_dtype)
            for output_dtype in output_dtypes:
                output_max = torch.iinfo(output_dtype).max

                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    output_image = transform(input_image)

                    actual = output_image.dtype
                    desired = output_dtype
                    self.assertEqual(actual, desired)
                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0, output_max

                    actual = torch.max(output_image).item()
                    desired = dtype_max_value[output_dtype]
                    if output_dtype.is_floating_point:
                        self.assertAlmostEqual(actual, desired)
                    # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
                    if input_max >= output_max:
                        error_term = 0
                    else:
                        self.assertEqual(actual, desired)
                        error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

                    self.assertEqual(actual_min, desired_min)
                    self.assertEqual(actual_max, desired_max + error_term)
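
A minimal sketch (illustration, not part of the diff) of the arithmetic behind error_term, using the uint8 -> int16 case:

import torch

input_max = torch.iinfo(torch.uint8).max              # 255
output_max = torch.iinfo(torch.int16).max             # 32767
factor = (output_max + 1) // (input_max + 1)          # 32768 // 256 == 128
error_term = 1 - factor                               # -127
assert input_max * factor == output_max + error_term  # 32640 == 32640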

    def test_convert_image_dtype_int_to_int_consistency(self):
        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
            input_max = torch.iinfo(input_dtype).max
            input_image = torch.tensor((0, input_max), dtype=input_dtype)
            for output_dtype in output_dtypes:
                output_max = torch.iinfo(output_dtype).max
                if output_max <= input_max:
                    continue

                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    inverse_transform = transforms.ConvertImageDtype(input_dtype)
                    output_image = inverse_transform(transform(input_image))

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0, input_max

                    self.assertEqual(actual_min, desired_min)
                    self.assertEqual(actual_max, desired_max)

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_to_tensor(self):
70 changes: 34 additions & 36 deletions torchvision/transforms/functional.py
@@ -113,9 +113,7 @@ def pil_to_tensor(pic):
    return img


def convert_image_dtype(
    image: torch.Tensor, dtype: torch.dtype = torch.float
) -> torch.Tensor:
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    Args:
@@ -125,28 +123,42 @@
    Returns:
        (torch.Tensor): Converted image
    .. note::
        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.
    Raises:
        TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32`
            or :class:`torch.int64` as well as for trying to cast
            :class:`torch.float64` to :class:`torch.int64`. These conversions are
            unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        return image.to(dtype)

    def float_to_int(image: torch.Tensor, dtype: torch.dtype, eps=1e-3) -> torch.Tensor:
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (image.dtype == torch.float64 and dtype == torch.int64):
            msg = (f"The cast from {image.dtype} to {dtype} cannot be performed safely, "
                   f"since {image.dtype} cannot ")
            raise TypeError(msg)
        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
    if image.dtype == dtype:
        return image

    def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        max = torch.iinfo(image.dtype).max
        image = image.to(dtype)
        return image / max
    if image.dtype.is_floating_point:
        # float to float
        if dtype.is_floating_point:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

    def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        eps = 1e-3
        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
    else:
        # int to float
        if dtype.is_floating_point:
            max = torch.iinfo(image.dtype).max
            image = image.to(dtype)
            return image / max

        # int to int
        input_max = torch.iinfo(image.dtype).max
        output_max = torch.iinfo(dtype).max

@@ -157,21 +169,7 @@ def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        else:
            factor = (output_max + 1) // (input_max + 1)
            image = image.to(dtype)
            return (image + 1) * factor - 1

    if image.dtype == dtype:
        return image

    if image.dtype.is_floating_point:
        if dtype.is_floating_point:
            return float_to_float(image, dtype)
        else:
            return float_to_int(image, dtype)
    else:
        if dtype.is_floating_point:
            return int_to_float(image, dtype)
        else:
            return int_to_int(image, dtype)
            return image * factor
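
A short sketch of the resulting behavior (illustration only, assuming this revision of torchvision is importable): converting up does not reach the new maximum exactly, but the round trip recovers the input, as the note above promises:

import torch
from torchvision.transforms import functional as F

image = torch.tensor((0, 255), dtype=torch.uint8)
up = F.convert_image_dtype(image, torch.int16)  # tensor([0, 32640]), not 32767
down = F.convert_image_dtype(up, torch.uint8)   # tensor([0, 255]), exact round trip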


def to_pil_image(pic, mode=None):
11 changes: 11 additions & 0 deletions torchvision/transforms/transforms.py
@@ -121,7 +121,18 @@ class ConvertImageDtype(object):
    Args:
        dtype (torch.dtype): Desired data type of the output
    .. note::
        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.
    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        self.dtype = dtype
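
A minimal usage sketch of the transform (illustration only, assuming this revision of torchvision is installed):

import torch
from torchvision import transforms

transform = transforms.ConvertImageDtype(torch.float32)
image = torch.tensor((0, 255), dtype=torch.uint8)
print(transform(image))  # tensor([0., 1.]): [0, 255] is rescaled to [0.0, 1.0]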

