Skip to content

Commit 0df93e4

Browse files
author
NoobMaster
authored
add sharpness (#1452)
* init sharpness * add batching * test commit * test factors
1 parent 3c2df18 commit 0df93e4

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

tensorflow_addons/image/color_ops.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# ==============================================================================
1515
"""Color operations.
1616
equalize: Equalizes image histogram
17+
sharpness: Sharpen image
1718
"""
1819

1920
import tensorflow as tf
2021

21-
from tensorflow_addons.utils.types import TensorLike
22+
from tensorflow_addons.utils.types import TensorLike, Number
2223
from tensorflow_addons.image.utils import to_4D_image, from_4D_image
24+
from tensorflow_addons.image.compose_ops import blend
2325

2426
from typing import Optional
2527
from functools import partial
@@ -84,7 +86,7 @@ def equalize(
8486
(num_images, num_rows, num_columns, num_channels) (NHWC), or
8587
(num_images, num_channels, num_rows, num_columns) (NCHW), or
8688
(num_rows, num_columns, num_channels) (HWC), or
87-
(num_channels, num_rows, num_columns) (HWC), or
89+
(num_channels, num_rows, num_columns) (CHW), or
8890
(num_rows, num_columns) (HW). The rank must be statically known (the
8991
shape is not `TensorShape(None)`).
9092
data_format: Either 'channels_first' or 'channels_last'
@@ -98,3 +100,55 @@ def equalize(
98100
fn = partial(equalize_image, data_format=data_format)
99101
image = tf.map_fn(fn, image)
100102
return from_4D_image(image, image_dims)
103+
104+
105+
def sharpness_image(image: TensorLike, factor: Number) -> tf.Tensor:
106+
"""Implements Sharpness function from PIL using TF ops."""
107+
orig_image = image
108+
image_dtype = image.dtype
109+
# Make image 4D for conv operation.
110+
image = tf.expand_dims(image, 0)
111+
# SMOOTH PIL Kernel.
112+
image = tf.cast(image, tf.float32)
113+
kernel = (
114+
tf.constant(
115+
[[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]
116+
)
117+
/ 13.0
118+
)
119+
# Tile across channel dimension.
120+
kernel = tf.tile(kernel, [1, 1, 3, 1])
121+
strides = [1, 1, 1, 1]
122+
degenerate = tf.nn.depthwise_conv2d(
123+
image, kernel, strides, padding="VALID", dilations=[1, 1]
124+
)
125+
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
126+
degenerate = tf.squeeze(tf.cast(degenerate, image_dtype), [0])
127+
128+
# For the borders of the resulting image, fill in the values of the
129+
# original image.
130+
mask = tf.ones_like(degenerate)
131+
padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
132+
padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
133+
result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
134+
# Blend the final result.
135+
blended = blend(result, orig_image, factor)
136+
return tf.cast(blended, image_dtype)
137+
138+
139+
def sharpness(image: TensorLike, factor: Number) -> tf.Tensor:
140+
"""Change sharpness of image(s)
141+
142+
Args:
143+
images: A tensor of shape
144+
(num_images, num_rows, num_columns, num_channels) (NHWC), or
145+
(num_rows, num_columns, num_channels) (HWC)
146+
factor: A floating point value or Tensor above 0.0.
147+
Returns:
148+
Image(s) with the same type and shape as `images`, sharper.
149+
"""
150+
image_dims = tf.rank(image)
151+
image = to_4D_image(image)
152+
fn = partial(sharpness_image, factor=factor)
153+
image = tf.map_fn(fn, image)
154+
return from_4D_image(image, image_dims)

tensorflow_addons/image/tests/color_ops_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020

2121
from tensorflow_addons.image import color_ops
22-
from PIL import Image, ImageOps
22+
from PIL import Image, ImageOps, ImageEnhance
2323

2424
_DTYPES = {
2525
np.uint8,
@@ -53,3 +53,24 @@ def test_equalize_channel_first(shape):
5353
image = tf.ones(shape=shape, dtype=tf.uint8)
5454
equalized = color_ops.equalize(image, "channels_first")
5555
np.testing.assert_equal(equalized.numpy(), image.numpy())
56+
57+
58+
@pytest.mark.parametrize("dtype", _DTYPES)
59+
@pytest.mark.parametrize("shape", [(5, 5, 3), (10, 5, 5, 3)])
60+
def test_sharpness_dtype_shape(dtype, shape):
61+
image = np.ones(shape=shape, dtype=dtype)
62+
sharp = color_ops.sharpness(tf.constant(image), 0).numpy()
63+
np.testing.assert_equal(sharp, image)
64+
assert sharp.dtype == image.dtype
65+
66+
67+
@pytest.mark.parametrize("factor", [0, 0.25, 0.5, 0.75, 1])
68+
def test_sharpness_with_PIL(factor):
69+
np.random.seed(0)
70+
image = np.random.randint(low=0, high=255, size=(10, 5, 5, 3), dtype=np.uint8)
71+
sharpened = np.stack(
72+
[ImageEnhance.Sharpness(Image.fromarray(i)).enhance(factor) for i in image]
73+
)
74+
np.testing.assert_allclose(
75+
color_ops.sharpness(tf.constant(image), factor).numpy(), sharpened, atol=1
76+
)

0 commit comments

Comments
 (0)