|
| 1 | +# Copyright 2024 The KerasCV Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import keras |
| 16 | +from keras import ops |
| 17 | + |
| 18 | +from keras_nlp.src.api_export import keras_nlp_export |
| 19 | +from keras_nlp.src.models.backbone import Backbone |
| 20 | +from keras_nlp.src.models.vit_det.vit_layers import AddPositionalEmbedding |
| 21 | +from keras_nlp.src.models.vit_det.vit_layers import ViTDetPatchingAndEmbedding |
| 22 | +from keras_nlp.src.models.vit_det.vit_layers import WindowedTransformerEncoder |
| 23 | + |
| 24 | + |
| 25 | +@keras_nlp_export("keras_nlp.models.ViTDetBackbone") |
| 26 | +class ViTDetBackbone(Backbone): |
| 27 | + """An implementation of ViT image encoder. |
| 28 | +
|
| 29 | + The ViTDetBackbone uses a windowed transformer encoder and relative |
| 30 | + positional encodings. The code has been adapted from [Segment Anything |
| 31 | + paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( |
| 32 | + https://github.com/facebookresearch/segment-anything) and [Detectron2]( |
| 33 | + https://github.com/facebookresearch/detectron2). |
| 34 | +
|
| 35 | + Args: |
| 36 | + hidden_size (int): The latent dimensionality to be projected |
| 37 | + into in the output of each stacked windowed transformer encoder. |
| 38 | + num_layers (int): The number of transformer encoder layers to |
| 39 | + stack in the Vision Transformer. |
| 40 | + intermediate_dim (int): The dimensionality of the hidden Dense |
| 41 | + layer in the transformer MLP head. |
| 42 | + num_heads (int): the number of heads to use in the |
| 43 | + `MultiHeadAttentionWithRelativePE` layer of each transformer |
| 44 | + encoder. |
| 45 | + global_attention_layer_indices (list): Indexes for blocks using |
| 46 | + global attention. |
| 47 | + image_shape (tuple[int], optional): The size of the input image in |
| 48 | + `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. |
| 49 | + include_rescaling (bool, optional): Whether to rescale the inputs. If |
| 50 | + set to `True`, inputs will be passed through a |
| 51 | + `Rescaling(1/255.0)` layer. Defaults to `False`. |
| 52 | + patch_size (int, optional): the patch size to be supplied to the |
| 53 | + Patching layer to turn input images into a flattened sequence of |
| 54 | + patches. Defaults to `16`. |
| 55 | + num_output_channels (int, optional): The number of channels (features) |
| 56 | + in the output (image encodings). Defaults to `256`. |
| 57 | + use_bias (bool, optional): Whether to use bias to project the keys, |
| 58 | + queries, and values in the attention layer. Defaults to `True`. |
| 59 | + use_abs_pos (bool, optional): Whether to add absolute positional |
| 60 | + embeddings to the output patches. Defaults to `True`. |
| 61 | + use_rel_pos (bool, optional): Whether to use relative positional |
| 62 | + emcodings in the attention layer. Defaults to `True`. |
| 63 | + window_size (int, optional): The size of the window for windowed |
| 64 | + attention in the transformer encoder blocks. Defaults to `14`. |
| 65 | + layer_norm_epsilon (int, optional): The epsilon to use in the layer |
| 66 | + normalization blocks in transformer encoder. Defaults to `1e-6`. |
| 67 | +
|
| 68 | + Examples: |
| 69 | + ```python |
| 70 | + input_data = np.ones((2, 224, 224, 3), dtype="float32") |
| 71 | +
|
| 72 | + # Pretrained ViTDetBackbone backbone. |
| 73 | + model = keras_nlp.models.ViTDetBackbone.from_preset("vit_det") |
| 74 | + model(input_data) |
| 75 | +
|
| 76 | + # Randomly initialized ViTDetBackbone backbone with a custom config. |
| 77 | + model = keras_nlp.models.ViTDetBackbone( |
| 78 | + image_shape = (16, 16, 3), |
| 79 | + patch_size = 2, |
| 80 | + hidden_size = 4, |
| 81 | + num_layers = 2, |
| 82 | + global_attention_layer_indices = [2, 5, 8, 11], |
| 83 | + intermediate_dim = 4 * 4, |
| 84 | + num_heads = 2, |
| 85 | + num_output_channels = 2, |
| 86 | + window_size = 2, |
| 87 | + ) |
| 88 | + model(input_data) |
| 89 | + ``` |
| 90 | + """ |
| 91 | + |
| 92 | + def __init__( |
| 93 | + self, |
| 94 | + hidden_size, |
| 95 | + num_layers, |
| 96 | + intermediate_dim, |
| 97 | + num_heads, |
| 98 | + global_attention_layer_indices, |
| 99 | + include_rescaling=True, |
| 100 | + image_shape=(1024, 1024, 3), |
| 101 | + patch_size=16, |
| 102 | + num_output_channels=256, |
| 103 | + use_bias=True, |
| 104 | + use_abs_pos=True, |
| 105 | + use_rel_pos=True, |
| 106 | + window_size=14, |
| 107 | + layer_norm_epsilon=1e-6, |
| 108 | + **kwargs |
| 109 | + ): |
| 110 | + # === Functional model === |
| 111 | + img_input = keras.layers.Input(shape=image_shape) |
| 112 | + # Check that the input image is well specified. |
| 113 | + if img_input.shape[-3] is None or img_input.shape[-2] is None: |
| 114 | + raise ValueError( |
| 115 | + "Height and width of the image must be specified" |
| 116 | + " in `image_shape`." |
| 117 | + ) |
| 118 | + if img_input.shape[-3] != img_input.shape[-2]: |
| 119 | + raise ValueError( |
| 120 | + "Input image must be square i.e. the height must" |
| 121 | + " be equal to the width in the `image_shape`" |
| 122 | + " tuple/tensor." |
| 123 | + ) |
| 124 | + img_size = img_input.shape[-3] |
| 125 | + x = img_input |
| 126 | + if include_rescaling: |
| 127 | + # Use common rescaling strategy across keras_cv |
| 128 | + x = keras.layers.Rescaling(1.0 / 255.0)(x) |
| 129 | + # VITDet scales inputs based on the standard ImageNet mean/stddev. |
| 130 | + x = (x - ops.array([0.485, 0.456, 0.406], dtype=x.dtype)) / ( |
| 131 | + ops.array([0.229, 0.224, 0.225], dtype=x.dtype) |
| 132 | + ) |
| 133 | + x = ViTDetPatchingAndEmbedding( |
| 134 | + kernel_size=(patch_size, patch_size), |
| 135 | + strides=(patch_size, patch_size), |
| 136 | + embed_dim=hidden_size, |
| 137 | + )(x) |
| 138 | + if use_abs_pos: |
| 139 | + x = AddPositionalEmbedding(img_size, patch_size, hidden_size)(x) |
| 140 | + for i in range(num_layers): |
| 141 | + x = WindowedTransformerEncoder( |
| 142 | + project_dim=hidden_size, |
| 143 | + intermediate_dim=intermediate_dim, |
| 144 | + num_heads=num_heads, |
| 145 | + use_bias=use_bias, |
| 146 | + use_rel_pos=use_rel_pos, |
| 147 | + window_size=( |
| 148 | + window_size |
| 149 | + if i not in global_attention_layer_indices |
| 150 | + else 0 |
| 151 | + ), |
| 152 | + input_size=(img_size // patch_size, img_size // patch_size), |
| 153 | + )(x) |
| 154 | + x = keras.layers.Conv2D( |
| 155 | + filters=num_output_channels, kernel_size=1, use_bias=False |
| 156 | + )(x) |
| 157 | + x = keras.layers.LayerNormalization(epsilon=1e-6)(x) |
| 158 | + x = keras.layers.Conv2D( |
| 159 | + filters=num_output_channels, |
| 160 | + kernel_size=3, |
| 161 | + padding="same", |
| 162 | + use_bias=False, |
| 163 | + )(x) |
| 164 | + x = keras.layers.LayerNormalization(epsilon=1e-6)(x) |
| 165 | + |
| 166 | + super().__init__(inputs=img_input, outputs=x, **kwargs) |
| 167 | + |
| 168 | + # === Config === |
| 169 | + self.patch_size = patch_size |
| 170 | + self.image_shape = image_shape |
| 171 | + self.hidden_size = hidden_size |
| 172 | + self.num_layers = num_layers |
| 173 | + self.intermediate_dim = intermediate_dim |
| 174 | + self.num_heads = num_heads |
| 175 | + self.num_output_channels = num_output_channels |
| 176 | + self.use_bias = use_bias |
| 177 | + self.use_rel_pos = use_rel_pos |
| 178 | + self.use_abs_pos = use_abs_pos |
| 179 | + self.window_size = window_size |
| 180 | + self.global_attention_layer_indices = global_attention_layer_indices |
| 181 | + self.layer_norm_epsilon = layer_norm_epsilon |
| 182 | + self.include_rescaling = include_rescaling |
| 183 | + |
| 184 | + def get_config(self): |
| 185 | + config = super().get_config() |
| 186 | + config.update( |
| 187 | + { |
| 188 | + "image_shape": self.image_shape, |
| 189 | + "include_rescaling": self.include_rescaling, |
| 190 | + "patch_size": self.patch_size, |
| 191 | + "hidden_size": self.hidden_size, |
| 192 | + "num_layers": self.num_layers, |
| 193 | + "intermediate_dim": self.intermediate_dim, |
| 194 | + "num_heads": self.num_heads, |
| 195 | + "num_output_channels": self.num_output_channels, |
| 196 | + "use_bias": self.use_bias, |
| 197 | + "use_abs_pos": self.use_abs_pos, |
| 198 | + "use_rel_pos": self.use_rel_pos, |
| 199 | + "window_size": self.window_size, |
| 200 | + "global_attention_layer_indices": self.global_attention_layer_indices, |
| 201 | + "layer_norm_epsilon": self.layer_norm_epsilon, |
| 202 | + } |
| 203 | + ) |
| 204 | + return config |
0 commit comments