
[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989

Merged 37 commits on Dec 12, 2024

Commits (changes from all commits)
741b889
vit base
sineeli Nov 14, 2024
13dae08
Add vit backbone, classifier and preprocessor layers
sineeli Nov 15, 2024
b64b137
update args
sineeli Nov 15, 2024
429d635
add default args
sineeli Nov 15, 2024
6d69abc
correct build method
sineeli Nov 15, 2024
2e87884
fix build issues
sineeli Nov 15, 2024
bd3cce0
fix bugs
sineeli Nov 16, 2024
4232a06
Update backbone args and configs
sineeli Nov 18, 2024
32b08c5
correct position ids dtype
sineeli Nov 18, 2024
cc938c6
build token layer
sineeli Nov 18, 2024
78812de
token layer build
sineeli Nov 18, 2024
8a20465
assign correct dtype to TokenLayer
sineeli Nov 18, 2024
de754cc
fix build shape of token layer
sineeli Nov 18, 2024
84ba896
correct mlp dens var names
sineeli Nov 18, 2024
7a70e16
use default norm mean and std as per hugging face config
sineeli Nov 18, 2024
81e3021
correct position_ids
sineeli Nov 19, 2024
d3061d6
remove separate token layer
sineeli Nov 19, 2024
618e163
correct position ids
sineeli Nov 19, 2024
2338637
Checkpoint conversion script and minor changes
sineeli Nov 21, 2024
95e5868
correct flag type
sineeli Nov 21, 2024
9d2e5bd
correct key name
sineeli Nov 21, 2024
ac7d1d3
use flat list later we can extract in between layers if needed
sineeli Nov 21, 2024
8065c01
Add test cases and correct dtype polciy for model
sineeli Nov 21, 2024
a8be824
add proper docstrings
sineeli Nov 21, 2024
3f027a0
correct test cases
sineeli Nov 22, 2024
05acb70
use numpy for test data
sineeli Nov 25, 2024
521df6f
nit
sineeli Nov 25, 2024
ae2b800
nit
sineeli Nov 27, 2024
26c2224
Merge branch 'master' into sineeli/ViT
sineeli Dec 2, 2024
92149d5
add presets
sineeli Dec 2, 2024
5374c70
load vit preset from hugging face directly
sineeli Dec 5, 2024
ebee9ef
nit
sineeli Dec 5, 2024
93064bd
handle num classes case for ViT
sineeli Dec 5, 2024
e206e7b
replace toke with first
sineeli Dec 9, 2024
7a39d5b
convert all vit checkpoints using tools
sineeli Dec 10, 2024
0827954
Add custom ImageClassifier for ViT
sineeli Dec 10, 2024
ae9319a
remove token pooling and rename representation_size to intermediate_dim
sineeli Dec 12, 2024
1 change: 1 addition & 0 deletions keras_hub/api/layers/__init__.py
@@ -63,6 +63,7 @@
SegFormerImageConverter,
)
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
5 changes: 5 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -330,6 +330,11 @@
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
ViTImageClassifierPreprocessor,
)
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
5 changes: 5 additions & 0 deletions keras_hub/src/models/vit/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, ViTBackbone)
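
With the presets registered above, a backbone can be built straight from a preset handle via the standard keras_hub `from_preset` constructor. A minimal sketch; the preset name used here is a placeholder, the real names are defined in `vit_presets.py` added by this PR:

import keras_hub

# Placeholder preset handle; see vit_presets.py in this PR for the actual names.
backbone = keras_hub.models.ViTBackbone.from_preset("vit_base_patch16_224_imagenet")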
152 changes: 152 additions & 0 deletions keras_hub/src/models/vit/vit_backbone.py
@@ -0,0 +1,152 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.vit.vit_layers import ViTEncoder
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.ViTBackbone")
class ViTBackbone(Backbone):
"""Vision Transformer (ViT) backbone.

This backbone implements the Vision Transformer architecture as described in
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
It transforms the input image into a sequence of patches, embeds them, and
then processes them through a series of Transformer encoder layers.

Args:
image_shape: A tuple or list of 3 integers representing the shape of the
input image `(height, width, channels)`. `height` and `width` must
be equal.
patch_size: int. The size of each image patch. The input image will be
divided into patches of shape `(patch_size, patch_size)`.
num_layers: int. The number of transformer encoder layers.
num_heads: int. The number of attention heads in each
Transformer encoder layer.
hidden_dim: int. The dimensionality of the hidden representations.
mlp_dim: int. The dimensionality of the intermediate MLP layer in
each Transformer encoder layer.
dropout_rate: float. The dropout rate for the Transformer encoder
layers.
attention_dropout: float. The dropout rate for the attention mechanism
in each Transformer encoder layer.
layer_norm_epsilon: float. Value used for numerical stability in
layer normalization.
use_mha_bias: bool. Whether to use bias in the multi-head
attention layers.
use_mlp_bias: bool. Whether to use bias in the MLP layers.
data_format: str. `"channels_last"` or `"channels_first"`, specifying
the data format for the input image. If `None`, defaults to
`"channels_last"`.
dtype: The dtype of the layer weights. Defaults to None.
**kwargs: Additional keyword arguments to be passed to the parent
`Backbone` class.
"""

def __init__(
self,
image_shape,
patch_size,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout_rate=0.0,
attention_dropout=0.0,
layer_norm_epsilon=1e-6,
use_mha_bias=True,
use_mlp_bias=True,
data_format=None,
dtype=None,
**kwargs,
):
# === Layers ===
data_format = standardize_data_format(data_format)
h_axis, w_axis, channels_axis = (
(-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
)
# Check that the input image is well specified.
if image_shape[h_axis] is None or image_shape[w_axis] is None:
raise ValueError(
f"Image shape must have defined height and width. Found `None` "
f"at index {h_axis} (height) or {w_axis} (width). "
f"Image shape: {image_shape}"
)
if image_shape[h_axis] != image_shape[w_axis]:
raise ValueError(
f"Image height and width must be equal. Found height: "
f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
f"indices {h_axis} and {w_axis} respectively. Image shape: "
f"{image_shape}"
)

num_channels = image_shape[channels_axis]

# === Functional Model ===
inputs = keras.layers.Input(shape=image_shape)

x = ViTPatchingAndEmbedding(
image_size=image_shape[h_axis],
patch_size=patch_size,
hidden_dim=hidden_dim,
num_channels=num_channels,
data_format=data_format,
dtype=dtype,
name="vit_patching_and_embedding",
)(inputs)

output = ViTEncoder(
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
layer_norm_epsilon=layer_norm_epsilon,
use_mha_bias=use_mha_bias,
use_mlp_bias=use_mlp_bias,
dtype=dtype,
name="vit_encoder",
)(x)

super().__init__(
inputs=inputs,
outputs=output,
dtype=dtype,
**kwargs,
)

# === Config ===
self.image_shape = image_shape
self.patch_size = patch_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.dropout_rate = dropout_rate
self.attention_dropout = attention_dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.use_mha_bias = use_mha_bias
self.use_mlp_bias = use_mlp_bias
self.data_format = data_format

def get_config(self):
config = super().get_config()
config.update(
{
"image_shape": self.image_shape,
"patch_size": self.patch_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"mlp_dim": self.mlp_dim,
"dropout_rate": self.dropout_rate,
"attention_dropout": self.attention_dropout,
"layer_norm_epsilon": self.layer_norm_epsilon,
"use_mha_bias": self.use_mha_bias,
"use_mlp_bias": self.use_mlp_bias,
}
)
return config
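
For reference, a small end-to-end sketch of the functional model above, using the same configuration as the test file below. The backbone maps a batch of images to a token sequence: with a 28x28 input and patch size 4 there are 49 patches plus the class token, each of width `hidden_dim`, matching the expected output shape in the test.

import numpy as np
from keras_hub.src.models.vit.vit_backbone import ViTBackbone

backbone = ViTBackbone(
    image_shape=(28, 28, 3),
    patch_size=4,
    num_layers=3,
    num_heads=6,
    hidden_dim=48,
    mlp_dim=48 * 4,
)
images = np.ones((2, 28, 28, 3), dtype="float32")
features = backbone(images)
print(features.shape)  # (2, 50, 48): 49 patches + 1 class token, hidden_dim=48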
37 changes: 37 additions & 0 deletions keras_hub/src/models/vit/vit_backbone_test.py
@@ -0,0 +1,37 @@
import pytest
from keras import ops

from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.tests.test_case import TestCase


class ViTBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"image_shape": (28, 28, 3),
"patch_size": 4,
"num_layers": 3,
"hidden_dim": 48,
"num_heads": 6,
"mlp_dim": 48 * 4,
"use_mha_bias": True,
}
self.input_size = 28
self.input_data = ops.ones((2, self.input_size, self.input_size, 3))

def test_backbone_basics(self):
self.run_backbone_test(
cls=ViTBackbone,
init_kwargs={**self.init_kwargs},
input_data=self.input_data,
expected_output_shape=(2, 50, 48),
run_quantization_check=False,
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=ViTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
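
The ViTImageClassifier and its preprocessor added by this PR are not shown in this excerpt. A minimal sketch of the intended task-level usage, assuming it follows the standard keras_hub ImageClassifier API (`from_preset` constructor with a `num_classes` argument); the preset handle below is a placeholder:

import numpy as np
import keras_hub

# Placeholder preset handle; the real names are defined in vit_presets.py.
classifier = keras_hub.models.ViTImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet",
    num_classes=10,
)
images = np.random.uniform(0, 255, size=(2, 224, 224, 3)).astype("float32")
preds = classifier.predict(images)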