From 54da5db4c816983bb641630519214c8f2f34d324 Mon Sep 17 00:00:00 2001 From: vikramtankasali <39167441+vikramtankasali@users.noreply.github.com> Date: Thu, 11 Jun 2020 14:16:12 +0100 Subject: [PATCH] Adjust hue accepts torch tensor (#2300) * Adjust hue * Adjust hue acceps torch.tensor uint8 Co-authored-by: Vikram Mukunda Rao Tankasali --- test/test_functional_tensor.py | 40 +++++++++ torchvision/transforms/functional_tensor.py | 92 +++++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 3eead60aed5..1a8c77c827f 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -6,6 +6,7 @@ import numpy as np import unittest import random +import colorsys from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple @@ -56,6 +57,45 @@ def test_crop(self): cropped_img_script = script_crop(img_tensor, top, left, height, width) self.assertTrue(torch.equal(img_cropped, cropped_img_script)) + def test_hsv2rgb(self): + shape = (3, 100, 150) + for _ in range(20): + img = torch.rand(*shape, dtype=torch.float) + ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1) + + h, s, v, = img.unbind(0) + h = h.flatten().numpy() + s = s.flatten().numpy() + v = v.flatten().numpy() + + rgb = [] + for h1, s1, v1 in zip(h, s, v): + rgb.append(colorsys.hsv_to_rgb(h1, s1, v1)) + + colorsys_img = torch.tensor(rgb, dtype=torch.float32) + max_diff = (ft_img - colorsys_img).abs().max() + self.assertLess(max_diff, 1e-5) + + def test_rgb2hsv(self): + shape = (3, 150, 100) + for _ in range(20): + img = torch.rand(*shape, dtype=torch.float) + ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1) + + r, g, b, = img.unbind(0) + r = r.flatten().numpy() + g = g.flatten().numpy() + b = b.flatten().numpy() + + hsv = [] + for r1, g1, b1 in zip(r, g, b): + hsv.append(colorsys.rgb_to_hsv(r1, g1, b1)) + + colorsys_img = torch.tensor(hsv, dtype=torch.float32) + + max_diff = (colorsys_img - ft_hsv_img).abs().max() + self.assertLess(max_diff, 1e-5) + def test_adjustments(self): script_adjust_brightness = torch.jit.script(F_t.adjust_brightness) script_adjust_contrast = torch.jit.script(F_t.adjust_contrast) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index c0815393c37..89440701d17 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -118,6 +118,54 @@ def adjust_contrast(img, contrast_factor): return _blend(img, mean, contrast_factor) +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args: + img (Tensor): Image to be adjusted. Image type is either uint8 or float. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + Tensor: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + orig_dtype = img.dtype + if img.dtype == torch.uint8: + img = img.to(dtype=torch.float32) / 255.0 + + img = _rgb2hsv(img) + h, s, v = img.unbind(0) + h += hue_factor + h = h % 1.0 + img = torch.stack((h, s, v)) + img_hue_adj = _hsv2rgb(img) + + if orig_dtype == torch.uint8: + img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype) + + return img_hue_adj + + def adjust_saturation(img, saturation_factor): # type: (Tensor, float) -> Tensor """Adjust color saturation of an RGB image. @@ -235,3 +283,47 @@ def _blend(img1, img2, ratio): # type: (Tensor, Tensor, float) -> Tensor bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255 return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) + + +def _rgb2hsv(img): + r, g, b = img.unbind(0) + + maxc, _ = torch.max(img, dim=0) + minc, _ = torch.min(img, dim=0) + + cr = maxc - minc + s = cr / maxc + rc = (maxc - r) / cr + gc = (maxc - g) / cr + bc = (maxc - b) / cr + + t = (maxc != minc) + s = t * s + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = (hr + hg + hb) + h = t * h + h = torch.fmod((h / 6.0 + 1.0), 1.0) + return torch.stack((h, s, maxc)) + + +def _hsv2rgb(img): + h, s, v = img.unbind(0) + i = torch.floor(h * 6.0) + f = (h * 6.0) - i + i = i.to(dtype=torch.int32) + + p = torch.clamp((v * (1.0 - s)), 0.0, 1.0) + q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0) + t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i == torch.arange(6)[:, None, None] + + a1 = torch.stack((v, q, p, p, t, v)) + a2 = torch.stack((t, v, v, q, p, p)) + a3 = torch.stack((p, p, t, v, v, q)) + a4 = torch.stack((a1, a2, a3)) + + return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)