Skip to content

Commit

Permalink
Split RandomlyResizedCrop into two API surfaces (RandomlyZoomedCrop, …
Browse files Browse the repository at this point in the history
…RandomCropAndResize) (keras-team#738)

* Sync

* Added zoom factor to RRC

* Used tf.shape

* dtype mismatch

* Debugging...

* Debugging...

* Fix example

* RRC uses preprocessing.transform now

* Minor error

* Minor bug

* minor issue

* minor issue

* minor bug

* Added unit tests

* Fix serialization test

* KerasCV simclr api update

* Augmenter

* Augmenter

* serialization test

* serialization test

* fix failing test

* Split RRC API into two layers

* Split RRC API into two layers

* Format serialization_test

* Implemented bounding box support

* Add preprocessing

* serialization test

* serialization test

* serialization test

* RandomCropAndResize in SimCLR

* RandomCropAndResize in SimCLR

* Update examples

* Update examples

* Update examples

* Update target_size

Co-authored-by: Luke Wood <lukewoodcs@gmail.com>
  • Loading branch information
AdityaKane2001 and LukeWood authored Sep 22, 2022
1 parent 1490ca9 commit 3bcf158
Show file tree
Hide file tree
Showing 12 changed files with 492 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,22 @@

import demo_utils

from keras_cv.layers.preprocessing import RandomResizedCrop
from keras_cv.layers import RandomCropAndResize


def main():
many_elephants = demo_utils.load_elephant_tensor(output_size=(300, 300))
layer = RandomResizedCrop(
layer = RandomCropAndResize(
target_size=(224, 224),
crop_area_factor=(0.08, 1.0),
crop_area_factor=(0.8, 1.0),
aspect_ratio_factor=(3.0 / 4.0, 4.0 / 3.0),
)
augmented = layer(many_elephants)
demo_utils.gallery_show(augmented.numpy())

layer = RandomCropAndResize(
target_size=(224, 224),
crop_area_factor=(0.01, 1.0),
aspect_ratio_factor=(3.0 / 4.0, 4.0 / 3.0),
)
augmented = layer(many_elephants)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import demo_utils

from keras_cv.layers import RandomlyZoomedCrop


def main():
many_elephants = demo_utils.load_elephant_tensor(output_size=(300, 300))
layer = RandomlyZoomedCrop(
target_size=(224, 224),
zoom_factor=(0.8, 1.2),
aspect_ratio_factor=(3.0 / 4.0, 4.0 / 3.0),
)
augmented = layer(many_elephants)
demo_utils.gallery_show(augmented.numpy())

layer = RandomlyZoomedCrop(
target_size=(224, 224),
zoom_factor=(0.08, 2.0),
aspect_ratio_factor=(3.0 / 4.0, 4.0 / 3.0),
)
augmented = layer(many_elephants)
demo_utils.gallery_show(augmented.numpy())


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@
RandomColorDegeneration,
)
from keras_cv.layers.preprocessing.random_color_jitter import RandomColorJitter
from keras_cv.layers.preprocessing.random_crop_and_resize import RandomCropAndResize
from keras_cv.layers.preprocessing.random_cutout import RandomCutout
from keras_cv.layers.preprocessing.random_flip import RandomFlip
from keras_cv.layers.preprocessing.random_gaussian_blur import RandomGaussianBlur
from keras_cv.layers.preprocessing.random_hue import RandomHue
from keras_cv.layers.preprocessing.random_jpeg_quality import RandomJpegQuality
from keras_cv.layers.preprocessing.random_resized_crop import RandomResizedCrop
from keras_cv.layers.preprocessing.random_saturation import RandomSaturation
from keras_cv.layers.preprocessing.random_sharpness import RandomSharpness
from keras_cv.layers.preprocessing.random_shear import RandomShear
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.solarization import Solarization
from keras_cv.layers.regularization.drop_path import DropPath
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
Expand Down
3 changes: 2 additions & 1 deletion keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@
RandomColorDegeneration,
)
from keras_cv.layers.preprocessing.random_color_jitter import RandomColorJitter
from keras_cv.layers.preprocessing.random_crop_and_resize import RandomCropAndResize
from keras_cv.layers.preprocessing.random_cutout import RandomCutout
from keras_cv.layers.preprocessing.random_flip import RandomFlip
from keras_cv.layers.preprocessing.random_gaussian_blur import RandomGaussianBlur
from keras_cv.layers.preprocessing.random_hue import RandomHue
from keras_cv.layers.preprocessing.random_jpeg_quality import RandomJpegQuality
from keras_cv.layers.preprocessing.random_resized_crop import RandomResizedCrop
from keras_cv.layers.preprocessing.random_rotation import RandomRotation
from keras_cv.layers.preprocessing.random_saturation import RandomSaturation
from keras_cv.layers.preprocessing.random_sharpness import RandomSharpness
from keras_cv.layers.preprocessing.random_shear import RandomShear
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.solarization import Solarization
4 changes: 2 additions & 2 deletions keras_cv/layers/preprocessing/augmenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_return_shapes(self):
preprocessing.Grayscale(
output_channels=1,
),
preprocessing.RandomResizedCrop(
preprocessing.RandomCropAndResize(
target_size=(100, 100),
crop_area_factor=(1, 1),
aspect_ratio_factor=(1, 1),
Expand All @@ -46,7 +46,7 @@ def test_in_tf_function(self):
preprocessing.Grayscale(
output_channels=1,
),
preprocessing.RandomResizedCrop(
preprocessing.RandomCropAndResize(
target_size=(100, 100),
crop_area_factor=(1, 1),
aspect_ratio_factor=(1, 1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class RandomResizedCrop(BaseImageAugmentationLayer):
class RandomCropAndResize(BaseImageAugmentationLayer):
"""Randomly crops a part of an image and resizes it to provided size.
This implementation takes an intuitive approach, where we crop the images to a
Expand Down Expand Up @@ -142,6 +142,9 @@ def call(self, inputs, training=True):
def augment_image(self, image, transformation, **kwargs):
return self._crop_and_resize(image, transformation)

def augment_target(self, target, **kwargs):
return target

def _resize(self, image, **kwargs):
outputs = tf.keras.preprocessing.image.smart_resize(
image, self.target_size, **kwargs
Expand Down Expand Up @@ -186,9 +189,6 @@ def _check_class_arguments(
f"aspect_ratio_factor={aspect_ratio_factor}"
)

def augment_target(self, augment_target, **kwargs):
return augment_target

def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
return self._crop_and_resize(
segmentation_mask, transformation, method="nearest"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from keras_cv.layers import preprocessing


class RandomResizedCropTest(tf.test.TestCase, parameterized.TestCase):
class RandomCropAndResizeTest(tf.test.TestCase, parameterized.TestCase):
height, width = 300, 300
batch_size = 4
target_size = (224, 224)
Expand All @@ -30,7 +30,7 @@ def test_train_augments_image(self):
input_image_shape = (self.batch_size, self.height, self.width, 3)
image = tf.random.uniform(shape=input_image_shape, seed=self.seed)

layer = preprocessing.RandomResizedCrop(
layer = preprocessing.RandomCropAndResize(
target_size=self.target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
Expand All @@ -46,7 +46,7 @@ def test_grayscale(self):
input_image_shape = (self.batch_size, self.height, self.width, 1)
image = tf.random.uniform(shape=input_image_shape)

layer = preprocessing.RandomResizedCrop(
layer = preprocessing.RandomCropAndResize(
target_size=self.target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
Expand All @@ -62,7 +62,7 @@ def test_preserves_image(self):
image_shape = (self.batch_size, self.height, self.width, 3)
image = tf.random.uniform(shape=image_shape)

layer = preprocessing.RandomResizedCrop(
layer = preprocessing.RandomCropAndResize(
target_size=self.target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
Expand All @@ -84,7 +84,7 @@ def test_target_size_errors(self, target_size):
ValueError,
"`target_size` must be tuple of two integers. Received target_size=(.*)",
):
_ = preprocessing.RandomResizedCrop(
_ = preprocessing.RandomCropAndResize(
target_size=target_size,
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
Expand All @@ -101,7 +101,7 @@ def test_aspect_ratio_factor_errors(self, aspect_ratio_factor):
"`aspect_ratio_factor` must be tuple of two positive floats or "
"keras_cv.core.FactorSampler instance. Received aspect_ratio_factor=(.*)",
):
_ = preprocessing.RandomResizedCrop(
_ = preprocessing.RandomCropAndResize(
target_size=(224, 224),
aspect_ratio_factor=aspect_ratio_factor,
crop_area_factor=(0.8, 1.0),
Expand All @@ -119,7 +119,7 @@ def test_crop_area_factor_errors(self, crop_area_factor):
"equal to 1 or keras_cv.core.FactorSampler instance. Received "
"crop_area_factor=(.*)",
):
_ = preprocessing.RandomResizedCrop(
_ = preprocessing.RandomCropAndResize(
target_size=(224, 224),
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=crop_area_factor,
Expand All @@ -136,7 +136,7 @@ def test_augment_sparse_segmentation_mask(self):
inputs = {"images": image, "segmentation_masks": mask}

# Crop-only to exactly 1/2 of the size
layer = preprocessing.RandomResizedCrop(
layer = preprocessing.RandomCropAndResize(
target_size=(150, 150),
aspect_ratio_factor=(1, 1),
crop_area_factor=(1, 1),
Expand All @@ -149,7 +149,7 @@ def test_augment_sparse_segmentation_mask(self):
self.assertAllClose(output["segmentation_masks"], input_mask_resized)

# Crop to an arbitrary size and make sure we don't do bad interpolation
layer = preprocessing.RandomResizedCrop(
layer = preprocessing.RandomCropAndResize(
target_size=(233, 233),
aspect_ratio_factor=(3 / 4, 4 / 3),
crop_area_factor=(0.8, 1.0),
Expand All @@ -172,7 +172,7 @@ def test_augment_one_hot_segmentation_mask(self):
inputs = {"images": image, "segmentation_masks": mask}

# Crop-only to exactly 1/2 of the size
layer = preprocessing.RandomResizedCrop(
layer = preprocessing.RandomCropAndResize(
target_size=(150, 150),
aspect_ratio_factor=(1, 1),
crop_area_factor=(1, 1),
Expand Down
Loading

0 comments on commit 3bcf158

Please sign in to comment.