Adjust hue accepts torch tensor (#2300)
* Adjust hue

* Adjust hue accepts torch.tensor uint8

Co-authored-by: Vikram Mukunda Rao Tankasali <vikramtankasali@devvm765.lla0.facebook.com>
vikramtankasali and Vikram Mukunda Rao Tankasali committed Jun 11, 2020
1 parent 747f406 commit 54da5db
Showing 2 changed files with 132 additions and 0 deletions.
40 changes: 40 additions & 0 deletions test/test_functional_tensor.py
@@ -6,6 +6,7 @@
import numpy as np
import unittest
import random
import colorsys
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


@@ -56,6 +57,45 @@ def test_crop(self):
        cropped_img_script = script_crop(img_tensor, top, left, height, width)
        self.assertTrue(torch.equal(img_cropped, cropped_img_script))

    def test_hsv2rgb(self):
        shape = (3, 100, 150)
        for _ in range(20):
            img = torch.rand(*shape, dtype=torch.float)
            ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)

            h, s, v = img.unbind(0)
            h = h.flatten().numpy()
            s = s.flatten().numpy()
            v = v.flatten().numpy()

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))

            colorsys_img = torch.tensor(rgb, dtype=torch.float32)
            max_diff = (ft_img - colorsys_img).abs().max()
            self.assertLess(max_diff, 1e-5)

    def test_rgb2hsv(self):
        shape = (3, 150, 100)
        for _ in range(20):
            img = torch.rand(*shape, dtype=torch.float)
            ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)

            r, g, b = img.unbind(0)
            r = r.flatten().numpy()
            g = g.flatten().numpy()
            b = b.flatten().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32)

            max_diff = (colorsys_img - ft_hsv_img).abs().max()
            self.assertLess(max_diff, 1e-5)

    def test_adjustments(self):
        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
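
Aside from the colorsys comparison in the tests above, one informal sanity check (not part of this commit) is that the two private conversion helpers they exercise should invert each other on float images. A minimal sketch, assuming the test file's import convention (torchvision.transforms.functional_tensor imported as F_t) and the helper names added in this commit:

import torch
import torchvision.transforms.functional_tensor as F_t

# Random float RGB image in [0, 1]; exactly gray pixels (where the HSV
# conversion is degenerate) are vanishingly unlikely with torch.rand.
img = torch.rand(3, 100, 150)
round_trip = F_t._hsv2rgb(F_t._rgb2hsv(img))   # RGB -> HSV -> RGB
print((round_trip - img).abs().max())          # expected to be on the order of float32 rounding error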
92 changes: 92 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -118,6 +118,54 @@ def adjust_contrast(img, contrast_factor):
    return _blend(img, mean, contrast_factor)


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (Tensor): Image to be adjusted. Image type is either uint8 or float.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        Tensor: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    orig_dtype = img.dtype
    if img.dtype == torch.uint8:
        img = img.to(dtype=torch.float32) / 255.0

    img = _rgb2hsv(img)
    h, s, v = img.unbind(0)
    h += hue_factor
    h = h % 1.0
    img = torch.stack((h, s, v))
    img_hue_adj = _hsv2rgb(img)

    if orig_dtype == torch.uint8:
        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

    return img_hue_adj
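
Not part of the diff: a minimal usage sketch of the new tensor code path described by the docstring above, assuming the module is imported the way the tests import it (torchvision.transforms.functional_tensor as F_t). The pixel values are arbitrary, chosen so that no pixel is pure gray:

import torch
import torchvision.transforms.functional_tensor as F_t

# A tiny CHW uint8 image (3 channels, 2x2 pixels).
img = torch.tensor([[[200,  30], [ 90,  10]],
                    [[ 40, 120], [200, 250]],
                    [[ 10, 220], [ 60, 128]]], dtype=torch.uint8)

# Shift the hue channel by 0.2 of a full turn; hue_factor must lie in [-0.5, 0.5].
out = F_t.adjust_hue(img, 0.2)
print(out.dtype, out.shape)   # torch.uint8 torch.Size([3, 2, 2])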


def adjust_saturation(img, saturation_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust color saturation of an RGB image.
@@ -235,3 +283,47 @@ def _blend(img1, img2, ratio):
    # type: (Tensor, Tensor, float) -> Tensor
    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb2hsv(img):
    r, g, b = img.unbind(0)

    maxc, _ = torch.max(img, dim=0)
    minc, _ = torch.min(img, dim=0)

    cr = maxc - minc
    s = cr / maxc
    rc = (maxc - r) / cr
    gc = (maxc - g) / cr
    bc = (maxc - b) / cr

    t = (maxc != minc)
    s = t * s
    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = (hr + hg + hb)
    h = t * h
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc))


def _hsv2rgb(img):
    h, s, v = img.unbind(0)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    mask = i == torch.arange(6)[:, None, None]

    a1 = torch.stack((v, q, p, p, t, v))
    a2 = torch.stack((t, v, v, q, p, p))
    a3 = torch.stack((p, p, t, v, v, q))
    a4 = torch.stack((a1, a2, a3))

    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
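
Not part of the diff: a small worked check of the sextant selection at the end of _hsv2rgb. The one-hot mask over i and the einsum pick, for every pixel, one entry from each of the three stacked rows a1, a2, a3; for example, pure red (H=0, S=1, V=1) lands in sextant 0 and should come back as (1, 0, 0). A sketch under the same import assumption as above:

import torch
import torchvision.transforms.functional_tensor as F_t

# One-pixel HSV images of shape (3, 1, 1): pure red (h=0) and pure green (h=1/3), with s = v = 1.
red_hsv = torch.tensor([[[0.0]], [[1.0]], [[1.0]]])
green_hsv = torch.tensor([[[1.0 / 3.0]], [[1.0]], [[1.0]]])

print(F_t._hsv2rgb(red_hsv).flatten())    # tensor([1., 0., 0.])
print(F_t._hsv2rgb(green_hsv).flatten())  # approximately tensor([0., 1., 0.])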
