|  | 
|  | 1 | +from keras.src.api_export import keras_export | 
|  | 2 | +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501 | 
|  | 3 | +    BaseImagePreprocessingLayer, | 
|  | 4 | +) | 
|  | 5 | +from keras.src.random import SeedGenerator | 
|  | 6 | + | 
|  | 7 | + | 
|  | 8 | +@keras_export("keras.layers.RandomGaussianBlur") | 
|  | 9 | +class RandomGaussianBlur(BaseImagePreprocessingLayer): | 
|  | 10 | +    """Applies random Gaussian blur to images for data augmentation. | 
|  | 11 | +
 | 
|  | 12 | +    This layer performs a Gaussian blur operation on input images with a | 
|  | 13 | +    randomly selected degree of blurring, controlled by the `factor` and | 
|  | 14 | +    `sigma` arguments. | 
|  | 15 | +
 | 
|  | 16 | +    Args: | 
|  | 17 | +        factor: A single float or a tuple of two floats. | 
|  | 18 | +            `factor` controls the extent to which the image hue is impacted. | 
|  | 19 | +            `factor=0.0` makes this layer perform a no-op operation, | 
|  | 20 | +            while a value of `1.0` performs the most aggressive | 
|  | 21 | +            blurring available. If a tuple is used, a `factor` is | 
|  | 22 | +            sampled between the two values for every image augmented. If a | 
|  | 23 | +            single float is used, a value between `0.0` and the passed float is | 
|  | 24 | +            sampled. Default is 1.0. | 
|  | 25 | +        kernel_size: Integer. Size of the Gaussian kernel used for blurring. | 
|  | 26 | +            Must be an odd integer. Default is 3. | 
|  | 27 | +        sigma: Float or tuple of two floats. Standard deviation of the Gaussian | 
|  | 28 | +            kernel. Controls the intensity of the blur. If a tuple is provided, | 
|  | 29 | +            a value is sampled between the two for each image. Default is 1.0. | 
|  | 30 | +        value_range: the range of values the incoming images will have. | 
|  | 31 | +            Represented as a two-number tuple written `[low, high]`. This is | 
|  | 32 | +            typically either `[0, 1]` or `[0, 255]` depending on how your | 
|  | 33 | +            preprocessing pipeline is set up. | 
|  | 34 | +        seed: Integer. Used to create a random seed. | 
|  | 35 | +    """ | 
|  | 36 | + | 
|  | 37 | +    _USE_BASE_FACTOR = False | 
|  | 38 | +    _FACTOR_BOUNDS = (0, 1) | 
|  | 39 | + | 
|  | 40 | +    def __init__( | 
|  | 41 | +        self, | 
|  | 42 | +        factor=1.0, | 
|  | 43 | +        kernel_size=3, | 
|  | 44 | +        sigma=1.0, | 
|  | 45 | +        value_range=(0, 255), | 
|  | 46 | +        data_format=None, | 
|  | 47 | +        seed=None, | 
|  | 48 | +        **kwargs, | 
|  | 49 | +    ): | 
|  | 50 | +        super().__init__(data_format=data_format, **kwargs) | 
|  | 51 | +        self._set_factor(factor) | 
|  | 52 | +        self.kernel_size = self._set_kernel_size(kernel_size, "kernel_size") | 
|  | 53 | +        self.sigma = self._set_factor_by_name(sigma, "sigma") | 
|  | 54 | +        self.value_range = value_range | 
|  | 55 | +        self.seed = seed | 
|  | 56 | +        self.generator = SeedGenerator(seed) | 
|  | 57 | + | 
|  | 58 | +    def _set_kernel_size(self, factor, name): | 
|  | 59 | +        error_msg = f"{name} must be an odd number. Received: {name}={factor}" | 
|  | 60 | +        if isinstance(factor, (tuple, list)): | 
|  | 61 | +            if len(factor) != 2: | 
|  | 62 | +                error_msg = ( | 
|  | 63 | +                    f"The `{name}` argument should be a number " | 
|  | 64 | +                    "(or a list of two numbers) " | 
|  | 65 | +                    f"Received: {name}={factor}" | 
|  | 66 | +                ) | 
|  | 67 | +                raise ValueError(error_msg) | 
|  | 68 | +            if (factor[0] % 2 == 0) or (factor[1] % 2 == 0): | 
|  | 69 | +                raise ValueError(error_msg) | 
|  | 70 | +            lower, upper = factor | 
|  | 71 | +        elif isinstance(factor, (int, float)): | 
|  | 72 | +            if factor % 2 == 0: | 
|  | 73 | +                raise ValueError(error_msg) | 
|  | 74 | +            lower, upper = factor, factor | 
|  | 75 | +        else: | 
|  | 76 | +            raise ValueError(error_msg) | 
|  | 77 | + | 
|  | 78 | +        return lower, upper | 
|  | 79 | + | 
|  | 80 | +    def _set_factor_by_name(self, factor, name): | 
|  | 81 | +        error_msg = ( | 
|  | 82 | +            f"The `{name}` argument should be a number " | 
|  | 83 | +            "(or a list of two numbers) " | 
|  | 84 | +            "in the range " | 
|  | 85 | +            f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. " | 
|  | 86 | +            f"Received: factor={factor}" | 
|  | 87 | +        ) | 
|  | 88 | +        if isinstance(factor, (tuple, list)): | 
|  | 89 | +            if len(factor) != 2: | 
|  | 90 | +                raise ValueError(error_msg) | 
|  | 91 | +            if ( | 
|  | 92 | +                factor[0] > self._FACTOR_BOUNDS[1] | 
|  | 93 | +                or factor[1] < self._FACTOR_BOUNDS[0] | 
|  | 94 | +            ): | 
|  | 95 | +                raise ValueError(error_msg) | 
|  | 96 | +            lower, upper = sorted(factor) | 
|  | 97 | +        elif isinstance(factor, (int, float)): | 
|  | 98 | +            if ( | 
|  | 99 | +                factor < self._FACTOR_BOUNDS[0] | 
|  | 100 | +                or factor > self._FACTOR_BOUNDS[1] | 
|  | 101 | +            ): | 
|  | 102 | +                raise ValueError(error_msg) | 
|  | 103 | +            factor = abs(factor) | 
|  | 104 | +            lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor] | 
|  | 105 | +        else: | 
|  | 106 | +            raise ValueError(error_msg) | 
|  | 107 | +        return lower, upper | 
|  | 108 | + | 
|  | 109 | +    def create_gaussian_kernel(self, kernel_size, sigma, num_channels): | 
|  | 110 | +        def get_gaussian_kernel1d(size, sigma): | 
|  | 111 | +            x = ( | 
|  | 112 | +                self.backend.numpy.arange(size, dtype=self.compute_dtype) | 
|  | 113 | +                - (size - 1) / 2 | 
|  | 114 | +            ) | 
|  | 115 | +            kernel1d = self.backend.numpy.exp(-0.5 * (x / sigma) ** 2) | 
|  | 116 | +            return kernel1d / self.backend.numpy.sum(kernel1d) | 
|  | 117 | + | 
|  | 118 | +        def get_gaussian_kernel2d(size, sigma): | 
|  | 119 | +            kernel1d_x = get_gaussian_kernel1d(size[0], sigma[0]) | 
|  | 120 | +            kernel1d_y = get_gaussian_kernel1d(size[1], sigma[1]) | 
|  | 121 | +            return self.backend.numpy.tensordot(kernel1d_y, kernel1d_x, axes=0) | 
|  | 122 | + | 
|  | 123 | +        kernel = get_gaussian_kernel2d(kernel_size, sigma) | 
|  | 124 | + | 
|  | 125 | +        kernel = self.backend.numpy.reshape( | 
|  | 126 | +            kernel, (kernel_size[0], kernel_size[1], 1, 1) | 
|  | 127 | +        ) | 
|  | 128 | +        kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1]) | 
|  | 129 | + | 
|  | 130 | +        kernel = self.backend.cast(kernel, self.compute_dtype) | 
|  | 131 | + | 
|  | 132 | +        return kernel | 
|  | 133 | + | 
|  | 134 | +    def get_random_transformation(self, data, training=True, seed=None): | 
|  | 135 | +        if not training: | 
|  | 136 | +            return None | 
|  | 137 | + | 
|  | 138 | +        if isinstance(data, dict): | 
|  | 139 | +            images = data["images"] | 
|  | 140 | +        else: | 
|  | 141 | +            images = data | 
|  | 142 | + | 
|  | 143 | +        images_shape = self.backend.shape(images) | 
|  | 144 | +        rank = len(images_shape) | 
|  | 145 | +        if rank == 3: | 
|  | 146 | +            batch_size = 1 | 
|  | 147 | +        elif rank == 4: | 
|  | 148 | +            batch_size = images_shape[0] | 
|  | 149 | +        else: | 
|  | 150 | +            raise ValueError( | 
|  | 151 | +                "Expected the input image to be rank 3 or 4. Received " | 
|  | 152 | +                f"inputs.shape={images_shape}" | 
|  | 153 | +            ) | 
|  | 154 | + | 
|  | 155 | +        seed = seed or self._get_seed_generator(self.backend._backend) | 
|  | 156 | + | 
|  | 157 | +        blur_probability = self.backend.random.uniform( | 
|  | 158 | +            shape=(batch_size,), | 
|  | 159 | +            minval=self.factor[0], | 
|  | 160 | +            maxval=self.factor[1], | 
|  | 161 | +            seed=seed, | 
|  | 162 | +        ) | 
|  | 163 | + | 
|  | 164 | +        random_threshold = self.backend.random.uniform( | 
|  | 165 | +            shape=(batch_size,), | 
|  | 166 | +            minval=0.0, | 
|  | 167 | +            maxval=1.0, | 
|  | 168 | +            seed=seed, | 
|  | 169 | +        ) | 
|  | 170 | +        should_apply_blur = random_threshold < blur_probability | 
|  | 171 | + | 
|  | 172 | +        blur_factor = ( | 
|  | 173 | +            self.backend.random.uniform( | 
|  | 174 | +                shape=(2,), | 
|  | 175 | +                minval=self.sigma[0], | 
|  | 176 | +                maxval=self.sigma[1], | 
|  | 177 | +                seed=seed, | 
|  | 178 | +                dtype=self.compute_dtype, | 
|  | 179 | +            ) | 
|  | 180 | +            + 1e-6 | 
|  | 181 | +        ) | 
|  | 182 | + | 
|  | 183 | +        return { | 
|  | 184 | +            "should_apply_blur": should_apply_blur, | 
|  | 185 | +            "blur_factor": blur_factor, | 
|  | 186 | +        } | 
|  | 187 | + | 
|  | 188 | +    def transform_images(self, images, transformation=None, training=True): | 
|  | 189 | +        images = self.backend.cast(images, self.compute_dtype) | 
|  | 190 | +        if training and transformation is not None: | 
|  | 191 | +            if self.data_format == "channels_first": | 
|  | 192 | +                images = self.backend.numpy.swapaxes(images, -3, -1) | 
|  | 193 | + | 
|  | 194 | +            blur_factor = transformation["blur_factor"] | 
|  | 195 | +            should_apply_blur = transformation["should_apply_blur"] | 
|  | 196 | + | 
|  | 197 | +            kernel = self.create_gaussian_kernel( | 
|  | 198 | +                self.kernel_size, | 
|  | 199 | +                blur_factor, | 
|  | 200 | +                self.backend.shape(images)[-1], | 
|  | 201 | +            ) | 
|  | 202 | + | 
|  | 203 | +            blur_images = self.backend.nn.depthwise_conv( | 
|  | 204 | +                images, | 
|  | 205 | +                kernel, | 
|  | 206 | +                strides=1, | 
|  | 207 | +                padding="same", | 
|  | 208 | +                data_format="channels_last", | 
|  | 209 | +            ) | 
|  | 210 | + | 
|  | 211 | +            images = self.backend.numpy.where( | 
|  | 212 | +                should_apply_blur[:, None, None, None], | 
|  | 213 | +                blur_images, | 
|  | 214 | +                images, | 
|  | 215 | +            ) | 
|  | 216 | + | 
|  | 217 | +            images = self.backend.numpy.clip( | 
|  | 218 | +                images, self.value_range[0], self.value_range[1] | 
|  | 219 | +            ) | 
|  | 220 | + | 
|  | 221 | +            if self.data_format == "channels_first": | 
|  | 222 | +                images = self.backend.numpy.swapaxes(images, -3, -1) | 
|  | 223 | + | 
|  | 224 | +            images = self.backend.cast(images, dtype=self.compute_dtype) | 
|  | 225 | + | 
|  | 226 | +        return images | 
|  | 227 | + | 
|  | 228 | +    def transform_labels(self, labels, transformation, training=True): | 
|  | 229 | +        return labels | 
|  | 230 | + | 
|  | 231 | +    def transform_segmentation_masks( | 
|  | 232 | +        self, segmentation_masks, transformation, training=True | 
|  | 233 | +    ): | 
|  | 234 | +        return segmentation_masks | 
|  | 235 | + | 
|  | 236 | +    def transform_bounding_boxes( | 
|  | 237 | +        self, bounding_boxes, transformation, training=True | 
|  | 238 | +    ): | 
|  | 239 | +        return bounding_boxes | 
|  | 240 | + | 
|  | 241 | +    def compute_output_shape(self, input_shape): | 
|  | 242 | +        return input_shape | 
|  | 243 | + | 
|  | 244 | +    def get_config(self): | 
|  | 245 | +        config = super().get_config() | 
|  | 246 | +        config.update( | 
|  | 247 | +            { | 
|  | 248 | +                "factor": self.factor, | 
|  | 249 | +                "kernel_size": self.kernel_size, | 
|  | 250 | +                "sigma": self.sigma, | 
|  | 251 | +                "value_range": self.value_range, | 
|  | 252 | +                "seed": self.seed, | 
|  | 253 | +            } | 
|  | 254 | +        ) | 
|  | 255 | +        return config | 
0 commit comments