Convert RandomZoom to backend-agnostic and improve affine_transform #574

Merged · 18 commits · Jul 26, 2023
Changes from 13 commits
63 changes: 48 additions & 15 deletions keras_core/backend/torch/image.py
@@ -86,8 +86,8 @@ def resize(
"constant": "zeros",
"nearest": "border",
# "wrap", not supported by torch
# "mirror", not supported by torch
"reflect": "reflection",
"mirror": "reflection",
# "reflect", not supported by torch
}
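This mapping swap works together with the `align_corners` change in the next hunk: with `align_corners=True`, torch's `"reflection"` padding mirrors about the centers of the edge pixels, which matches the SciPy/Keras `"mirror"` convention (edge pixel not repeated) rather than `"reflect"` (edge pixel repeated). A minimal probe to confirm the convention, written for this writeup rather than taken from the PR:

```python
import torch
import torch.nn.functional as F

# Row image with pixel values [0, 1, 2, 3].
image = torch.arange(4.0).reshape(1, 1, 1, 4)

# With align_corners=True, normalized x in [-1, 1] spans the pixel
# centers 0..3, so x = -5/3 targets pixel coordinate -1.
grid = torch.tensor([[[[-5 / 3, 0.0]]]])

out = F.grid_sample(
    image, grid, mode="nearest",
    padding_mode="reflection", align_corners=True,
)
# The "mirror" extension is ... 2 1 | 0 1 2 3, so coordinate -1 reads 1.0;
# a "reflect" extension (... 1 0 | 0 1 2 3) would read 0.0 instead.
print(out.item())  # 1.0
```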


@@ -122,7 +122,7 @@ def _apply_grid_transform(
grid,
mode=interpolation,
padding_mode=fill_mode,
align_corners=False,
align_corners=True,
)
# Fill with required color
if fill_value is not None:
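For intuition on the `align_corners` flip: with `True`, normalized coordinate -1 lands on the *center* of the first pixel; with `False`, it lands on the pixel's left *edge*, half a pixel further out. A small sketch (mine, not from the PR):

```python
import torch
import torch.nn.functional as F

img = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])  # (N, C, H, W)
grid = torch.tensor([[[[-1.0, 0.0]]]])          # leftmost normalized x

# align_corners=True: x = -1 is the center of pixel 0 -> exactly 1.0
print(F.grid_sample(img, grid, align_corners=True).item())   # 1.0

# align_corners=False: x = -1 is the left edge of pixel 0, so bilinear
# sampling mixes in the zero padding: 0.5 * 1.0 + 0.5 * 0.0 = 0.5
print(F.grid_sample(img, grid, align_corners=False).item())  # 0.5
```

The `True` convention is also what the coordinate normalization added to `affine_transform` below assumes, since it divides by `(w - 1)` and `(h - 1)`.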
@@ -187,7 +187,8 @@ def affine_transform(
f"transform.shape={transform.shape}"
)

if fill_mode != "constant":
# the default of tnn.grid_sample is zeros
if fill_mode != "constant" or (fill_mode == "constant" and fill_value == 0):
fill_value = None
fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]
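Since the second disjunct already implies `fill_mode == "constant"`, the new condition reduces to: skip the explicit fill whenever `grid_sample`'s built-in `"zeros"` padding already produces the right result. An equivalent, slightly simpler form (a sketch, not what was merged):

```python
if fill_mode != "constant" or fill_value == 0:
    # grid_sample's "zeros" padding already fills with 0; the other
    # fill modes are handled natively by padding_mode.
    fill_value = None
fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]
```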

@@ -202,21 +203,53 @@
if data_format == "channels_last":
image = image.permute((0, 3, 1, 2))

batch_size = image.shape[0]

# get indices
shape = [*image.shape[-2:], image.shape[-3]] # (H, W, C)
meshgrid = torch.meshgrid(
*[torch.arange(size) for size in shape], indexing="ij"
)
indices = torch.concatenate(
[torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1
)
indices = torch.tile(indices, (batch_size, 1, 1, 1, 1))
indices = indices.to(transform)

# swap the values
a0 = transform[:, 0].clone()
a2 = transform[:, 2].clone()
b1 = transform[:, 4].clone()
b2 = transform[:, 5].clone()
transform[:, 0] = b1
transform[:, 2] = b2
transform[:, 4] = a0
transform[:, 5] = a2

# deal with transform
h, w = image.shape[2], image.shape[3]
theta = torch.zeros((image.shape[0], 2, 3)).to(transform)
theta[:, 0, 0] = transform[:, 0]
theta[:, 0, 1] = transform[:, 1] * h / w
theta[:, 0, 2] = (
transform[:, 2] * 2 / w + theta[:, 0, 0] + theta[:, 0, 1] - 1
transform = torch.nn.functional.pad(
transform, pad=[0, 1, 0, 0], mode="constant", value=1
)
theta[:, 1, 0] = transform[:, 3] * w / h
theta[:, 1, 1] = transform[:, 4]
theta[:, 1, 2] = (
transform[:, 5] * 2 / h + theta[:, 1, 0] + theta[:, 1, 1] - 1
transform = torch.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2].clone()
offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0])
transform[:, 0:2, 2] = 0

# transform the indices
coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = torch.moveaxis(coordinates, source=-1, destination=1)
coordinates += torch.reshape(offset, shape=(*offset.shape, 1, 1, 1))
coordinates = coordinates[:, 0:2, ..., 0]
coordinates = coordinates.permute((0, 2, 3, 1))

# normalize coordinates
h, w = image.shape[-2], image.shape[-1]
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / (w - 1) * 2.0 - 1.0
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / (h - 1) * 2.0 - 1.0
grid = torch.stack(
[coordinates[:, :, :, 1], coordinates[:, :, :, 0]], dim=-1
)

grid = tnn.affine_grid(theta, image.shape)
affined = _apply_grid_transform(
image, grid, interpolation, fill_mode, fill_value
)
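Taken together, the new torch path builds homogeneous pixel indices, applies the (row-swapped) 3x3 transform, and normalizes the resulting coordinates into `grid_sample`'s `[-1, 1]` space. Here is a self-contained sketch of the same convention with hypothetical zoom parameters; Keras-style transforms map an *output* pixel `(x, y)` to the *input* pixel `(a0*x + a1*y + a2, b0*x + b1*y + b2)`:

```python
import torch

# Hypothetical 2x zoom-out about the center of an 8x8 image:
# offset = (size - 1) / 2 * (1 - zoom) = 3.5 * (1 - 2) = -3.5.
a0, a1, a2 = 2.0, 0.0, -3.5
b0, b1, b2 = 0.0, 2.0, -3.5

h = w = 8
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
x_in = a0 * xs + a1 * ys + a2
y_in = b0 * xs + b1 * ys + b2

# Normalize to [-1, 1] under the align_corners=True convention.
gx = x_in / (w - 1) * 2.0 - 1.0
gy = y_in / (h - 1) * 2.0 - 1.0
grid = torch.stack([gx, gy], dim=-1)[None]  # (1, H, W, 2), x before y

image = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)
zoomed_out = torch.nn.functional.grid_sample(
    image, grid, mode="nearest", padding_mode="zeros", align_corners=True
)
print(zoomed_out.shape)  # torch.Size([1, 1, 8, 8])
```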
226 changes: 174 additions & 52 deletions keras_core/layers/preprocessing/random_zoom.py
@@ -1,14 +1,11 @@
import numpy as np

from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
from keras_core.utils.module_utils import tensorflow as tf
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.random.seed_generator import SeedGenerator


@keras_core_export("keras_core.layers.RandomZoom")
class RandomZoom(Layer):
class RandomZoom(TFDataLayer):
"""A preprocessing layer which randomly zooms images during training.

This layer will randomly zoom in or out on each axis of an image
@@ -18,17 +15,24 @@ class RandomZoom(Layer):
of integer or floating point dtype.
By default, the layer will output floats.

**Note:** This layer wraps `tf.keras.layers.RandomZoom`. It cannot
be used as part of the compiled computation graph of a model with
any backend other than TensorFlow.
It can however be used with any backend when running eagerly.
It can also always be used as part of an input preprocessing pipeline
with any backend (outside the model itself), which is how we recommend
to use this layer.
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format,
or `(..., channels, height, width)`, in `"channels_first"` format.

Output shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., target_height, target_width, channels)`,
or `(..., channels, target_height, target_width)`,
in `"channels_first"` format.

**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).

**Note:** Given the same transform, the resulting image may differ
under the torch backend compared to other backends, due to differences
in the interpolation implementation of `tnn.grid_sample`.

Args:
height_factor: a float represented as fraction of value,
or a tuple of size 2 representing lower and upper bound
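The `tf.data` note above is the practical payoff of moving from the old TF-wrapping `Layer` to `TFDataLayer`. A hedged usage sketch (shapes and the lambda wrapper are illustrative, not from the PR):

```python
import numpy as np
import tensorflow as tf
from keras_core import layers

zoom = layers.RandomZoom(height_factor=0.2)

# The layer can run inside a tf.data pipeline on any Keras backend.
ds = (
    tf.data.Dataset.from_tensor_slices(
        np.random.random((8, 32, 32, 3)).astype("float32")
    )
    .batch(4)
    .map(lambda x: zoom(x, training=True))
)
for batch in ds.take(1):
    print(batch.shape)  # (4, 32, 32, 3)
```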
@@ -71,22 +75,30 @@ class RandomZoom(Layer):
seed: Integer. Used to create a random seed.
fill_value: a float represents the value to be filled outside
the boundaries when `fill_mode="constant"`.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.

Example:

>>> input_img = np.random.random((32, 224, 224, 3))
>>> layer = keras_core.layers.RandomZoom(.5, .2)
>>> out_img = layer(input_img)

Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format.

Output shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format.
"""

_FACTOR_VALIDATION_ERROR = (
"The `factor` argument should be a number (or a list of two numbers) "
"in the range [-1.0, 1.0]. "
)
_SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest")
_SUPPORTED_INTERPOLATION = ("nearest", "bilinear")

def __init__(
self,
height_factor,
@@ -95,46 +107,156 @@ def __init__(
interpolation="bilinear",
seed=None,
fill_value=0.0,
name=None,
data_format=None,
**kwargs,
):
if not tf.available:
raise ImportError(
"Layer RandomZoom requires TensorFlow. "
"Install it via `pip install tensorflow`."
super().__init__(**kwargs)
self.height_factor = height_factor
self.height_lower, self.height_upper = self._set_factor(
height_factor, "height_factor"
)
self.width_factor = width_factor
if width_factor is not None:
self.width_lower, self.width_upper = self._set_factor(
width_factor, "width_factor"
)
if fill_mode not in self._SUPPORTED_FILL_MODE:
raise NotImplementedError(
f"Unknown `fill_mode` {fill_mode}. Expected of one "
f"{self._SUPPORTED_FILL_MODE}."
)
if interpolation not in self._SUPPORTED_INTERPOLATION:
raise NotImplementedError(
f"Unknown `interpolation` {interpolation}. Expected of one "
f"{self._SUPPORTED_INTERPOLATION}."
)

super().__init__(name=name, **kwargs)
self.seed = seed or backend.random.make_default_seed()
self.layer = tf.keras.layers.RandomZoom(
height_factor=height_factor,
width_factor=width_factor,
fill_mode=fill_mode,
interpolation=interpolation,
seed=self.seed,
name=name,
fill_value=fill_value,
**kwargs,
)
self._allow_non_tensor_positional_args = True
self._convert_input_args = False
self.fill_mode = fill_mode
self.fill_value = fill_value
self.interpolation = interpolation
self.seed = seed
self.generator = SeedGenerator(seed)
self.data_format = backend.standardize_data_format(data_format)

self.supports_jit = False

def _set_factor(self, factor, factor_name):
if isinstance(factor, (tuple, list)):
if len(factor) != 2:
raise ValueError(
self._FACTOR_VALIDATION_ERROR
+ f"Received: {factor_name}={factor}"
)
self._check_factor_range(factor[0])
self._check_factor_range(factor[1])
lower, upper = sorted(factor)
elif isinstance(factor, (int, float)):
self._check_factor_range(factor)
factor = abs(factor)
lower, upper = [-factor, factor]
else:
raise ValueError(
self._FACTOR_VALIDATION_ERROR
+ f"Received: {factor_name}={factor}"
)
return lower, upper

def _check_factor_range(self, input_number):
if input_number > 1.0 or input_number < -1.0:
raise ValueError(
self._FACTOR_VALIDATION_ERROR
+ f"Received: input_number={input_number}"
)

def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
outputs = self.layer.call(inputs, training=training)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
inputs = self.backend.cast(inputs, self.compute_dtype)
if training:
return self._randomly_zoom_inputs(inputs)
else:
return inputs

def _randomly_zoom_inputs(self, inputs):
unbatched = len(inputs.shape) == 3
if unbatched:
inputs = self.backend.numpy.expand_dims(inputs, axis=0)

batch_size = self.backend.shape(inputs)[0]
if self.data_format == "channels_first":
height = inputs.shape[-2]
width = inputs.shape[-1]
else:
height = inputs.shape[-3]
width = inputs.shape[-2]

seed_generator = self._get_seed_generator(self.backend._backend)
height_zoom = self.backend.random.uniform(
minval=1.0 + self.height_lower,
maxval=1.0 + self.height_upper,
shape=[batch_size, 1],
seed=seed_generator,
)
if self.width_factor is not None:
width_zoom = self.backend.random.uniform(
minval=1.0 + self.width_lower,
maxval=1.0 + self.width_upper,
shape=[batch_size, 1],
seed=seed_generator,
)
else:
width_zoom = height_zoom
zooms = self.backend.cast(
self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1),
dtype="float32",
)

outputs = self.backend.image.affine_transform(
inputs,
transform=self._get_zoom_matrix(zooms, height, width),
interpolation=self.interpolation,
fill_mode=self.fill_mode,
fill_value=self.fill_value,
data_format=self.data_format,
)

if unbatched:
outputs = self.backend.numpy.squeeze(outputs, axis=0)
return outputs

def _get_zoom_matrix(self, zooms, image_height, image_width):
num_zooms = self.backend.shape(zooms)[0]
# The zoom matrix looks like:
# [[zx 0 0]
# [0 zy 0]
# [0 0 1]]
# where the last entry is implicit.
# zoom matrices are always float32.
x_offset = ((image_width - 1.0) / 2.0) * (1.0 - zooms[:, 0:1])
y_offset = ((image_height - 1.0) / 2.0) * (1.0 - zooms[:, 1:])
return self.backend.numpy.concatenate(
[
zooms[:, 0:1],
self.backend.numpy.zeros((num_zooms, 1)),
x_offset,
self.backend.numpy.zeros((num_zooms, 1)),
zooms[:, 1:],
y_offset,
self.backend.numpy.zeros((num_zooms, 2)),
],
axis=1,
)

def compute_output_shape(self, input_shape):
return tuple(self.layer.compute_output_shape(input_shape))
return input_shape

def get_config(self):
config = self.layer.get_config()
config.update({"seed": self.seed})
return config
base_config = super().get_config()
config = {
"height_factor": self.height_factor,
"width_factor": self.width_factor,
"fill_mode": self.fill_mode,
"interpolation": self.interpolation,
"seed": self.seed,
"fill_value": self.fill_value,
"data_format": self.data_format,
}
return {**base_config, **config}
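To close the loop on `_get_zoom_matrix`: its offsets make the image center a fixed point of the zoom, since `zx * (W-1)/2 + (W-1)/2 * (1 - zx) = (W-1)/2`. A quick check of that convention plus a usage sketch (factor values are arbitrary; `training=False` returns the inputs unchanged):

```python
import numpy as np
from keras_core import layers

# Center-fixing property of the zoom matrix
# (zoom < 1.0 zooms in, zoom > 1.0 zooms out).
W = H = 65
zx = zy = 0.5
x_offset = (W - 1) / 2 * (1 - zx)
y_offset = (H - 1) / 2 * (1 - zy)
cx, cy = (W - 1) / 2, (H - 1) / 2
assert np.isclose(zx * cx + x_offset, cx)
assert np.isclose(zy * cy + y_offset, cy)

# Usage: backend-agnostic random zoom.
layer = layers.RandomZoom(height_factor=(-0.3, -0.1), seed=42)
images = np.random.random((2, 64, 64, 3)).astype("float32")
print(layer(images, training=True).shape)  # (2, 64, 64, 3)
assert np.allclose(layer(images, training=False), images)
```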