Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert RandomZoom to backend-agnostic and improve affine_transform #574

Merged
merged 18 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Convert RandomZoom
  • Loading branch information
james77777778 committed Jul 21, 2023
commit 0ff51ba17932506caf9ec4db44c561c62cc0e26b
191 changes: 154 additions & 37 deletions keras_core/layers/preprocessing/random_zoom.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import numpy as np

from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
from keras_core.utils.module_utils import tensorflow as tf
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.random.seed_generator import SeedGenerator


@keras_core_export("keras_core.layers.RandomZoom")
class RandomZoom(Layer):
class RandomZoom(TFDataLayer):
"""A preprocessing layer which randomly zooms images during training.

This layer will randomly zoom in or out on each axis of an image
Expand Down Expand Up @@ -87,6 +84,13 @@ class RandomZoom(Layer):
`(..., height, width, channels)`, in `"channels_last"` format.
"""

_FACTOR_VALIDATION_ERROR = (
"The `factor` argument should be a number (or a list of two numbers) "
"in the range [-1.0, 1.0]. "
)
_SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest")
_SUPPORTED_INTERPOLATION = ("nearest", "bilinear")

def __init__(
self,
height_factor,
Expand All @@ -95,46 +99,159 @@ def __init__(
interpolation="bilinear",
seed=None,
fill_value=0.0,
name=None,
data_format=None,
**kwargs,
):
if not tf.available:
raise ImportError(
"Layer RandomZoom requires TensorFlow. "
"Install it via `pip install tensorflow`."
super().__init__(**kwargs)
self.height_factor = height_factor
self.height_lower, self.height_upper = self._set_factor(
height_factor, "height_factor"
)
self.width_factor = width_factor
if width_factor is not None:
self.width_lower, self.width_upper = self._set_factor(
width_factor, "width_factor"
)
self._check_fill_mode_and_interpolation(fill_mode, interpolation)

self.fill_mode = fill_mode
self.fill_value = fill_value
self.interpolation = interpolation
self.seed = seed
self.generator = SeedGenerator(seed)
self.data_format = backend.standardize_data_format(data_format)

super().__init__(name=name, **kwargs)
self.seed = seed or backend.random.make_default_seed()
self.layer = tf.keras.layers.RandomZoom(
height_factor=height_factor,
width_factor=width_factor,
fill_mode=fill_mode,
interpolation=interpolation,
seed=self.seed,
name=name,
fill_value=fill_value,
**kwargs,
)
self._allow_non_tensor_positional_args = True
self._convert_input_args = False
self.supports_jit = False

def _set_factor(self, factor, factor_name):
if isinstance(factor, (tuple, list)):
if len(factor) != 2:
raise ValueError(
self._FACTOR_VALIDATION_ERROR
+ f"Received: {factor_name}={factor}"
)
self._check_factor_range(factor[0])
self._check_factor_range(factor[1])
lower, upper = sorted(factor)
elif isinstance(factor, (int, float)):
self._check_factor_range(factor)
factor = abs(factor)
lower, upper = [-factor, factor]
else:
raise ValueError(
self._FACTOR_VALIDATION_ERROR
+ f"Received: {factor_name}={factor}"
)
return lower, upper

def _check_factor_range(self, input_number):
if input_number > 1.0 or input_number < -1.0:
raise ValueError(
self._FACTOR_VALIDATION_ERROR
+ f"Received: input_number={input_number}"
)

def _check_fill_mode_and_interpolation(self, fill_mode, interpolation):
if fill_mode not in self._SUPPORTED_FILL_MODE:
raise NotImplementedError(
f"Unknown `fill_mode` {fill_mode}. Expected of one "
f"{self._SUPPORTED_FILL_MODE}."
)
if interpolation not in self._SUPPORTED_INTERPOLATION:
raise NotImplementedError(
f"Unknown `interpolation` {interpolation}. Expected of one "
f"{self._SUPPORTED_INTERPOLATION}."
)

def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
outputs = self.layer.call(inputs, training=training)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
inputs = self.backend.cast(inputs, self.compute_dtype)
if training:
return self._randomly_zoom_inputs(inputs)
else:
return inputs

def _randomly_zoom_inputs(self, inputs):
unbatched = len(inputs.shape) == 3
if unbatched:
inputs = self.backend.numpy.expand_dims(inputs, axis=0)

batch_size = self.backend.shape(inputs)[0]
if self.data_format == "channels_first":
height = inputs.shape[-2]
width = inputs.shape[-1]
else:
height = inputs.shape[-3]
width = inputs.shape[-2]

seed_generator = self._get_seed_generator(self.backend._backend)
height_zoom = self.backend.random.uniform(
minval=1.0 + self.height_lower,
maxval=1.0 + self.height_upper,
shape=[batch_size, 1],
seed=seed_generator,
)
if self.width_factor is not None:
width_zoom = self.backend.random.uniform(
minval=1.0 + self.width_lower,
maxval=1.0 + self.width_upper,
shape=[batch_size, 1],
seed=seed_generator,
)
else:
width_zoom = height_zoom
zooms = self.backend.cast(
self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1),
dtype="float32",
)

outputs = self.backend.image.affine_transform(
inputs,
transform=self._get_zoom_matrix(zooms, height, width),
interpolation=self.interpolation,
fill_mode=self.fill_mode,
fill_value=self.fill_value,
data_format=self.data_format,
)

if unbatched:
outputs = self.backend.numpy.squeeze(outputs, axis=0)
return outputs

def _get_zoom_matrix(self, zooms, image_height, image_width):
num_zooms = self.backend.shape(zooms)[0]
# The zoom matrix looks like:
# [[zx 0 0]
# [0 zy 0]
# [0 0 1]]
# where the last entry is implicit.
# zoom matrices are always float32.
x_offset = ((image_width - 1.0) / 2.0) * (1.0 - zooms[:, 0:1])
y_offset = ((image_height - 1.0) / 2.0) * (1.0 - zooms[:, 1:])
return self.backend.numpy.concatenate(
[
zooms[:, 0:1],
self.backend.numpy.zeros((num_zooms, 1)),
x_offset,
self.backend.numpy.zeros((num_zooms, 1)),
zooms[:, 1:],
y_offset,
self.backend.numpy.zeros((num_zooms, 2)),
],
axis=1,
)

def compute_output_shape(self, input_shape):
return tuple(self.layer.compute_output_shape(input_shape))
return input_shape

def get_config(self):
config = self.layer.get_config()
config.update({"seed": self.seed})
return config
base_config = super().get_config()
config = {
"height_factor": self.height_factor,
"width_factor": self.width_factor,
"fill_mode": self.fill_mode,
"interpolation": self.interpolation,
"seed": self.seed,
"fill_value": self.fill_value,
"data_format": self.data_format,
}
return {**base_config, **config}
60 changes: 42 additions & 18 deletions keras_core/layers/preprocessing/random_zoom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,27 @@ def test_random_zoom(self, height_factor, width_factor):

def test_random_zoom_out_correctness(self):
input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1))
expected_output = np.asarray(
[
[0, 0, 0, 0, 0],
[0, 5, 7, 9, 0],
[0, 10, 12, 14, 0],
[0, 20, 22, 24, 0],
[0, 0, 0, 0, 0],
]
)
if backend.backend() == "torch":
# slightly different output with torch backend
expected_output = np.asarray(
[
[0, 0, 0, 0, 0],
[0, 6, 7, 9, 0],
[0, 11, 12, 14, 0],
[0, 21, 22, 24, 0],
[0, 0, 0, 0, 0],
]
)
else:
expected_output = np.asarray(
[
[0, 0, 0, 0, 0],
[0, 5, 7, 9, 0],
[0, 10, 12, 14, 0],
[0, 20, 22, 24, 0],
[0, 0, 0, 0, 0],
]
)
expected_output = backend.convert_to_tensor(
np.reshape(expected_output, (1, 5, 5, 1))
)
Expand All @@ -60,15 +72,27 @@ def test_random_zoom_out_correctness(self):

def test_random_zoom_in_correctness(self):
input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1))
expected_output = np.asarray(
[
[6, 7, 7, 8, 8],
[11, 12, 12, 13, 13],
[11, 12, 12, 13, 13],
[16, 17, 17, 18, 18],
[16, 17, 17, 18, 18],
]
)
if backend.backend() == "torch":
# slightly different output with torch backend
expected_output = np.asarray(
[
[6, 6, 7, 7, 8],
[6, 6, 7, 7, 8],
[11, 11, 12, 12, 13],
[11, 11, 12, 12, 13],
[16, 16, 17, 17, 18],
]
)
else:
expected_output = np.asarray(
[
[6, 7, 7, 8, 8],
[11, 12, 12, 13, 13],
[11, 12, 12, 13, 13],
[16, 17, 17, 18, 18],
[16, 17, 17, 18, 18],
]
)
expected_output = backend.convert_to_tensor(
np.reshape(expected_output, (1, 5, 5, 1))
)
Expand Down