[prototype] Gaussian Blur clean up (#6888)
* Refactor gaussian_blur

* Add conditional reshape

* Further refactoring

* Remove unused import.
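
The "conditional reshape" above refers to the ndim-based handling in the diff below: 3D (C, H, W) inputs now get a temporary batch dimension and inputs with more than 4 dimensions are flattened into a single batch before the conv2d-based blur, with the original shape restored afterwards. A minimal standalone sketch of that pattern, for illustration only (the helper name and the `op` callable are not part of the commit):

import torch

def _apply_with_conditional_reshape(op, image: torch.Tensor) -> torch.Tensor:
    # Illustrative only: apply an op that expects a single (N, C, H, W) batch
    # to inputs with 3 or more dimensions, then restore the original shape.
    shape = image.shape
    ndim = image.ndim
    if ndim == 3:
        image = image.unsqueeze(dim=0)             # add a batch dimension
    elif ndim > 4:
        image = image.reshape((-1,) + shape[-3:])  # flatten leading batch dims

    output = op(image)

    if ndim == 3:
        output = output.squeeze(dim=0)
    elif ndim > 4:
        output = output.reshape(shape)
    return output

The previous implementation only handled the >4D case (via a needs_unsquash flag) and relied on the private _FT cast/squeeze helpers, which the refactor drops.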
datumbox authored Nov 2, 2022
1 parent c4c0ef9 commit 1921613
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -5,7 +5,6 @@
 import torch
 from torch.nn.functional import conv2d, pad as torch_pad
 from torchvision.prototype import features
-from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image


@@ -68,9 +67,9 @@ def normalize(


 def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
-    lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
+    lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
     x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
-    kernel1d = torch.softmax(-x.pow_(2), dim=0)
+    kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
     return kernel1d


@@ -89,54 +88,61 @@ def gaussian_blur_image_tensor(
     # TODO: consider deprecating integers from sigma on the future
     if isinstance(kernel_size, int):
         kernel_size = [kernel_size, kernel_size]
-    if len(kernel_size) != 2:
+    elif len(kernel_size) != 2:
         raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
     for ksize in kernel_size:
         if ksize % 2 == 0 or ksize < 0:
             raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

     if sigma is None:
         sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
-
-    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
-        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
-    if isinstance(sigma, (int, float)):
-        sigma = [float(sigma), float(sigma)]
-    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
-        sigma = [sigma[0], sigma[0]]
-    if len(sigma) != 2:
-        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
+    else:
+        if isinstance(sigma, (list, tuple)):
+            length = len(sigma)
+            if length == 1:
+                s = float(sigma[0])
+                sigma = [s, s]
+            elif length != 2:
+                raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
+        elif isinstance(sigma, (int, float)):
+            s = float(sigma)
+            sigma = [s, s]
+        else:
+            raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
     for s in sigma:
         if s <= 0.0:
             raise ValueError(f"sigma should have positive values. Got {sigma}")

     if image.numel() == 0:
         return image

+    dtype = image.dtype
     shape = image.shape
-
-    if image.ndim > 4:
+    ndim = image.ndim
+    if ndim == 3:
+        image = image.unsqueeze(dim=0)
+    elif ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    else:
-        needs_unsquash = False

-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
-    kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
+    fp = torch.is_floating_point(image)
+    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
+    kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])

-    image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
+    output = image if fp else image.to(dtype=torch.float32)

     # padding = (left, right, top, bottom)
     padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
-    output = torch_pad(image, padding, mode="reflect")
-    output = conv2d(output, kernel, groups=output.shape[-3])
-
-    output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)
+    output = torch_pad(output, padding, mode="reflect")
+    output = conv2d(output, kernel, groups=shape[-3])

-    if needs_unsquash:
+    if ndim == 3:
+        output = output.squeeze(dim=0)
+    elif ndim > 4:
         output = output.reshape(shape)

+    if not fp:
+        output = output.round_().to(dtype=dtype)
+
     return output


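For reference, a small usage sketch of the refactored kernel. The import path is an assumption (the prototype functional namespace re-exporting gaussian_blur_image_tensor is not shown in this diff):

import torch
from torchvision.prototype.transforms.functional import gaussian_blur_image_tensor  # assumed re-export

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)   # 3D (C, H, W) input
out = gaussian_blur_image_tensor(img, kernel_size=[5, 5], sigma=[1.5, 1.5])
assert out.shape == img.shape and out.dtype == torch.uint8    # integer inputs are rounded back to their dtype

video = torch.rand(2, 4, 3, 32, 32)                           # extra leading dims exercise the reshape path
out = gaussian_blur_image_tensor(video, kernel_size=3)        # int kernel_size; sigma derived as ksize * 0.15 + 0.35
assert out.shape == video.shape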
