|
| 1 | +import keras |
| 2 | + |
| 3 | +from keras_hub.src.api_export import keras_hub_export |
| 4 | +from keras_hub.src.models.backbone import Backbone |
| 5 | +from keras_hub.src.models.vit.vit_layers import ViTEncoder |
| 6 | +from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding |
| 7 | +from keras_hub.src.utils.keras_utils import standardize_data_format |
| 8 | + |
| 9 | + |
| 10 | +@keras_hub_export("keras_hub.models.ViTBackbone") |
| 11 | +class ViTBackbone(Backbone): |
| 12 | + """Vision Transformer (ViT) backbone. |
| 13 | +
|
| 14 | + This backbone implements the Vision Transformer architecture as described in |
| 15 | + [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). |
| 16 | + It transforms the input image into a sequence of patches, embeds them, and |
| 17 | + then processes them through a series of Transformer encoder layers. |
| 18 | +
|
| 19 | + Args: |
| 20 | + image_shape: A tuple or list of 3 integers representing the shape of the |
| 21 | + input image `(height, width, channels)`, `height` and `width` must |
| 22 | + be equal. |
| 23 | + patch_size: int. The size of each image patch, the input image will be |
| 24 | + divided into patches of shape `(patch_size, patch_size)`. |
| 25 | + num_layers: int. The number of transformer encoder layers. |
| 26 | + num_heads: int. specifying the number of attention heads in each |
| 27 | + Transformer encoder layer. |
| 28 | + hidden_dim: int. The dimensionality of the hidden representations. |
| 29 | + mlp_dim: int. The dimensionality of the intermediate MLP layer in |
| 30 | + each Transformer encoder layer. |
| 31 | + dropout_rate: float. The dropout rate for the Transformer encoder |
| 32 | + layers. |
| 33 | + attention_dropout: float. The dropout rate for the attention mechanism |
| 34 | + in each Transformer encoder layer. |
| 35 | + layer_norm_epsilon: float. Value used for numerical stability in |
| 36 | + layer normalization. |
| 37 | + use_mha_bias: bool. Whether to use bias in the multi-head |
| 38 | + attention layers. |
| 39 | + use_mlp_bias: bool. Whether to use bias in the MLP layers. |
| 40 | + data_format: str. `"channels_last"` or `"channels_first"`, specifying |
| 41 | + the data format for the input image. If `None`, defaults to |
| 42 | + `"channels_last"`. |
| 43 | + dtype: The dtype of the layer weights. Defaults to None. |
| 44 | + **kwargs: Additional keyword arguments to be passed to the parent |
| 45 | + `Backbone` class. |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + image_shape, |
| 51 | + patch_size, |
| 52 | + num_layers, |
| 53 | + num_heads, |
| 54 | + hidden_dim, |
| 55 | + mlp_dim, |
| 56 | + dropout_rate=0.0, |
| 57 | + attention_dropout=0.0, |
| 58 | + layer_norm_epsilon=1e-6, |
| 59 | + use_mha_bias=True, |
| 60 | + use_mlp_bias=True, |
| 61 | + data_format=None, |
| 62 | + dtype=None, |
| 63 | + **kwargs, |
| 64 | + ): |
| 65 | + # === Laters === |
| 66 | + data_format = standardize_data_format(data_format) |
| 67 | + h_axis, w_axis, channels_axis = ( |
| 68 | + (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3) |
| 69 | + ) |
| 70 | + # Check that the input image is well specified. |
| 71 | + if image_shape[h_axis] is None or image_shape[w_axis] is None: |
| 72 | + raise ValueError( |
| 73 | + f"Image shape must have defined height and width. Found `None` " |
| 74 | + f"at index {h_axis} (height) or {w_axis} (width). " |
| 75 | + f"Image shape: {image_shape}" |
| 76 | + ) |
| 77 | + if image_shape[h_axis] != image_shape[w_axis]: |
| 78 | + raise ValueError( |
| 79 | + f"Image height and width must be equal. Found height: " |
| 80 | + f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at " |
| 81 | + f"indices {h_axis} and {w_axis} respectively. Image shape: " |
| 82 | + f"{image_shape}" |
| 83 | + ) |
| 84 | + |
| 85 | + num_channels = image_shape[channels_axis] |
| 86 | + |
| 87 | + # === Functional Model === |
| 88 | + inputs = keras.layers.Input(shape=image_shape) |
| 89 | + |
| 90 | + x = ViTPatchingAndEmbedding( |
| 91 | + image_size=image_shape[h_axis], |
| 92 | + patch_size=patch_size, |
| 93 | + hidden_dim=hidden_dim, |
| 94 | + num_channels=num_channels, |
| 95 | + data_format=data_format, |
| 96 | + dtype=dtype, |
| 97 | + name="vit_patching_and_embedding", |
| 98 | + )(inputs) |
| 99 | + |
| 100 | + output = ViTEncoder( |
| 101 | + num_layers=num_layers, |
| 102 | + num_heads=num_heads, |
| 103 | + hidden_dim=hidden_dim, |
| 104 | + mlp_dim=mlp_dim, |
| 105 | + dropout_rate=dropout_rate, |
| 106 | + attention_dropout=attention_dropout, |
| 107 | + layer_norm_epsilon=layer_norm_epsilon, |
| 108 | + use_mha_bias=use_mha_bias, |
| 109 | + use_mlp_bias=use_mlp_bias, |
| 110 | + dtype=dtype, |
| 111 | + name="vit_encoder", |
| 112 | + )(x) |
| 113 | + |
| 114 | + super().__init__( |
| 115 | + inputs=inputs, |
| 116 | + outputs=output, |
| 117 | + dtype=dtype, |
| 118 | + **kwargs, |
| 119 | + ) |
| 120 | + |
| 121 | + # === Config === |
| 122 | + self.image_shape = image_shape |
| 123 | + self.patch_size = patch_size |
| 124 | + self.num_layers = num_layers |
| 125 | + self.num_heads = num_heads |
| 126 | + self.hidden_dim = hidden_dim |
| 127 | + self.mlp_dim = mlp_dim |
| 128 | + self.dropout_rate = dropout_rate |
| 129 | + self.attention_dropout = attention_dropout |
| 130 | + self.layer_norm_epsilon = layer_norm_epsilon |
| 131 | + self.use_mha_bias = use_mha_bias |
| 132 | + self.use_mlp_bias = use_mlp_bias |
| 133 | + self.data_format = data_format |
| 134 | + |
| 135 | + def get_config(self): |
| 136 | + config = super().get_config() |
| 137 | + config.update( |
| 138 | + { |
| 139 | + "image_shape": self.image_shape, |
| 140 | + "patch_size": self.patch_size, |
| 141 | + "num_layers": self.num_layers, |
| 142 | + "num_heads": self.num_heads, |
| 143 | + "hidden_dim": self.hidden_dim, |
| 144 | + "mlp_dim": self.mlp_dim, |
| 145 | + "dropout_rate": self.dropout_rate, |
| 146 | + "attention_dropout": self.attention_dropout, |
| 147 | + "layer_norm_epsilon": self.layer_norm_epsilon, |
| 148 | + "use_mha_bias": self.use_mha_bias, |
| 149 | + "use_mlp_bias": self.use_mlp_bias, |
| 150 | + } |
| 151 | + ) |
| 152 | + return config |
0 commit comments