
Commit 15564ca

[ViT] Vision Transformer (ViT) backbone, layers, and image classifier (#1989)
* vit base
* Add vit backbone, classifier and preprocessor layers
* update args
* add default args
* correct build method
* fix build issues
* fix bugs
* Update backbone args and configs
* correct position ids dtype
* build token layer
* token layer build
* assign correct dtype to TokenLayer
* fix build shape of token layer
* correct mlp dense var names
* use default norm mean and std as per hugging face config
* correct position_ids
* remove separate token layer
* correct position ids
* Checkpoint conversion script and minor changes
* correct flag type
* correct key name
* use flat list, later we can extract in between layers if needed
* Add test cases and correct dtype policy for model
* add proper docstrings
* correct test cases
* use numpy for test data
* nit
* nit
* add presets
* load vit preset from hugging face directly
* nit
* handle num classes case for ViT
* replace token with first
* convert all vit checkpoints using tools
* Add custom ImageClassifier for ViT
* remove token pooling and rename representation_size to intermediate_dim
1 parent 5180e78 commit 15564ca

14 files changed (+1513 lines, −0 lines)


keras_hub/api/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@
     SegFormerImageConverter,
 )
 from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
+from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
 )

keras_hub/api/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -330,6 +330,11 @@
 from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
     VGGImageClassifierPreprocessor,
 )
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier
+from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+    ViTImageClassifierPreprocessor,
+)
 from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
 from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
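
Not part of the diff, but for context: these registrations are what surface the new ViT symbols on the public `keras_hub.layers` and `keras_hub.models` namespaces. A minimal import sketch, assuming the generated public API mirrors the `keras_hub/api/*` modules as is the usual KerasHub convention:

import keras_hub

# Public symbols added by this commit (names as registered above).
image_converter_cls = keras_hub.layers.ViTImageConverter
backbone_cls = keras_hub.models.ViTBackbone
classifier_cls = keras_hub.models.ViTImageClassifier
preprocessor_cls = keras_hub.models.ViTImageClassifierPreprocessor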
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, ViTBackbone)
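
`register_presets` attaches the preset table defined in `vit_presets.py` to `ViTBackbone`, which is what lets `from_preset(...)` constructors resolve ViT checkpoints. A hedged sketch with a placeholder handle, since no preset names appear in this excerpt:

import keras_hub

# "<vit_preset_name>" is a placeholder; the real handles live in
# keras_hub/src/models/vit/vit_presets.py, which is not shown in this diff.
backbone = keras_hub.models.ViTBackbone.from_preset("<vit_preset_name>")
classifier = keras_hub.models.ViTImageClassifier.from_preset("<vit_preset_name>")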
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.vit.vit_layers import ViTEncoder
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.ViTBackbone")
class ViTBackbone(Backbone):
    """Vision Transformer (ViT) backbone.

    This backbone implements the Vision Transformer architecture as described
    in [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
    It transforms the input image into a sequence of patches, embeds them, and
    then processes them through a series of Transformer encoder layers.

    Args:
        image_shape: A tuple or list of 3 integers representing the shape of
            the input image `(height, width, channels)`; `height` and `width`
            must be equal.
        patch_size: int. The size of each image patch; the input image will be
            divided into patches of shape `(patch_size, patch_size)`.
        num_layers: int. The number of Transformer encoder layers.
        num_heads: int. The number of attention heads in each Transformer
            encoder layer.
        hidden_dim: int. The dimensionality of the hidden representations.
        mlp_dim: int. The dimensionality of the intermediate MLP layer in
            each Transformer encoder layer.
        dropout_rate: float. The dropout rate for the Transformer encoder
            layers.
        attention_dropout: float. The dropout rate for the attention mechanism
            in each Transformer encoder layer.
        layer_norm_epsilon: float. Value used for numerical stability in
            layer normalization.
        use_mha_bias: bool. Whether to use bias in the multi-head
            attention layers.
        use_mlp_bias: bool. Whether to use bias in the MLP layers.
        data_format: str. `"channels_last"` or `"channels_first"`, specifying
            the data format for the input image. If `None`, defaults to
            `"channels_last"`.
        dtype: The dtype of the layer weights. Defaults to `None`.
        **kwargs: Additional keyword arguments to be passed to the parent
            `Backbone` class.
    """

    def __init__(
        self,
        image_shape,
        patch_size,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout_rate=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-6,
        use_mha_bias=True,
        use_mlp_bias=True,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        data_format = standardize_data_format(data_format)
        h_axis, w_axis, channels_axis = (
            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
        )
        # Check that the input image is well specified.
        if image_shape[h_axis] is None or image_shape[w_axis] is None:
            raise ValueError(
                f"Image shape must have defined height and width. Found `None` "
                f"at index {h_axis} (height) or {w_axis} (width). "
                f"Image shape: {image_shape}"
            )
        if image_shape[h_axis] != image_shape[w_axis]:
            raise ValueError(
                f"Image height and width must be equal. Found height: "
                f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
                f"indices {h_axis} and {w_axis} respectively. Image shape: "
                f"{image_shape}"
            )

        num_channels = image_shape[channels_axis]

        # === Functional Model ===
        inputs = keras.layers.Input(shape=image_shape)

        x = ViTPatchingAndEmbedding(
            image_size=image_shape[h_axis],
            patch_size=patch_size,
            hidden_dim=hidden_dim,
            num_channels=num_channels,
            data_format=data_format,
            dtype=dtype,
            name="vit_patching_and_embedding",
        )(inputs)

        output = ViTEncoder(
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
            attention_dropout=attention_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            use_mha_bias=use_mha_bias,
            use_mlp_bias=use_mlp_bias,
            dtype=dtype,
            name="vit_encoder",
        )(x)

        super().__init__(
            inputs=inputs,
            outputs=output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.image_shape = image_shape
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_mha_bias = use_mha_bias
        self.use_mlp_bias = use_mlp_bias
        self.data_format = data_format

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_shape": self.image_shape,
                "patch_size": self.patch_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "mlp_dim": self.mlp_dim,
                "dropout_rate": self.dropout_rate,
                "attention_dropout": self.attention_dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "use_mha_bias": self.use_mha_bias,
                "use_mlp_bias": self.use_mlp_bias,
            }
        )
        return config
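
Not part of the commit, but a short usage sketch of the backbone defined above, using the same configuration as the unit test that follows; the expected output shape `(2, 50, 48)` in that test corresponds to 49 patch tokens plus one class token at `hidden_dim=48`:

import numpy as np

from keras_hub.src.models.vit.vit_backbone import ViTBackbone

# 28x28 images cut into 4x4 patches, embedded to a 48-dim token sequence.
backbone = ViTBackbone(
    image_shape=(28, 28, 3),
    patch_size=4,
    num_layers=3,
    num_heads=6,
    hidden_dim=48,
    mlp_dim=48 * 4,
)
images = np.ones((2, 28, 28, 3), dtype="float32")
features = backbone(images)
print(features.shape)  # (2, 50, 48)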
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
import pytest
from keras import ops

from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.tests.test_case import TestCase


class ViTBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "image_shape": (28, 28, 3),
            "patch_size": 4,
            "num_layers": 3,
            "hidden_dim": 48,
            "num_heads": 6,
            "mlp_dim": 48 * 4,
            "use_mha_bias": True,
        }
        self.input_size = 28
        self.input_data = ops.ones((2, self.input_size, self.input_size, 3))

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ViTBackbone,
            init_kwargs={**self.init_kwargs},
            input_data=self.input_data,
            expected_output_shape=(2, 50, 48),
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ViTBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
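
For reference, the arithmetic behind `expected_output_shape=(2, 50, 48)` in the test above (a sketch, not part of the test file):

image_size, patch_size, hidden_dim, batch_size = 28, 4, 48, 2
num_patches = (image_size // patch_size) ** 2  # 7 * 7 = 49 patch tokens
sequence_length = num_patches + 1              # plus the prepended class token
print((batch_size, sequence_length, hidden_dim))  # (2, 50, 48)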
