[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989
Merged
Changes from all commits — 37 commits, all authored by sineeli:
741b889  vit base
13dae08  Add vit backbone, classifier and preprocessor layers
b64b137  update args
429d635  add default args
6d69abc  correct build method
2e87884  fix build issues
bd3cce0  fix bugs
4232a06  Update backbone args and configs
32b08c5  correct position ids dtype
cc938c6  build token layer
78812de  token layer build
8a20465  assign correct dtype to TokenLayer
de754cc  fix build shape of token layer
84ba896  correct mlp dens var names
7a70e16  use default norm mean and std as per hugging face config
81e3021  correct position_ids
d3061d6  remove separate token layer
618e163  correct position ids
2338637  Checkpoint conversion script and minor changes
95e5868  correct flag type
9d2e5bd  correct key name
ac7d1d3  use flat list later we can extract in between layers if needed
8065c01  Add test cases and correct dtype polciy for model
a8be824  add proper docstrings
3f027a0  correct test cases
05acb70  use numpy for test data
521df6f  nit
ae2b800  nit
26c2224  Merge branch 'master' into sineeli/ViT
92149d5  add presets
5374c70  load vit preset from hugging face directly
ebee9ef  nit
93064bd  handle num classes case for ViT
e206e7b  replace toke with first
7a39d5b  convert all vit checkpoints using tools
0827954  Add custom ImageClassifier for ViT
ae9319a  remove token pooling and rename representation_size to intermediate_dim
New file — ViT preset registration (@@ -0,0 +1,5 @@):

from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, ViTBackbone)
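Registering `backbone_presets` against `ViTBackbone` is what exposes the ViT presets through the standard `from_preset` constructor. A minimal usage sketch, assuming a preset has been registered; the preset name below is a placeholder for illustration, not necessarily one shipped in this PR:

import keras_hub

# Load a registered ViT preset by name (the name here is illustrative only).
backbone = keras_hub.models.ViTBackbone.from_preset(
    "vit_base_patch16_224_imagenet"
)
# The backbone maps a batch of images to a sequence of patch embeddings,
# i.e. (batch, height, width, channels) -> (batch, sequence_length, hidden_dim).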
New file — keras_hub/src/models/vit/vit_backbone.py (@@ -0,0 +1,152 @@):
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.vit.vit_layers import ViTEncoder
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.ViTBackbone")
class ViTBackbone(Backbone):
    """Vision Transformer (ViT) backbone.

    This backbone implements the Vision Transformer architecture as described
    in [An Image is Worth 16x16 Words: Transformers for Image Recognition at
    Scale](https://arxiv.org/abs/2010.11929). It transforms the input image
    into a sequence of patches, embeds them, and then processes them through
    a series of Transformer encoder layers.

    Args:
        image_shape: A tuple or list of 3 integers representing the shape of
            the input image `(height, width, channels)`. `height` and `width`
            must be equal.
        patch_size: int. The size of each image patch; the input image will
            be divided into patches of shape `(patch_size, patch_size)`.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of attention heads in each Transformer
            encoder layer.
        hidden_dim: int. The dimensionality of the hidden representations.
        mlp_dim: int. The dimensionality of the intermediate MLP layer in
            each Transformer encoder layer.
        dropout_rate: float. The dropout rate for the Transformer encoder
            layers.
        attention_dropout: float. The dropout rate for the attention mechanism
            in each Transformer encoder layer.
        layer_norm_epsilon: float. Value used for numerical stability in
            layer normalization.
        use_mha_bias: bool. Whether to use bias in the multi-head
            attention layers.
        use_mlp_bias: bool. Whether to use bias in the MLP layers.
        data_format: str. `"channels_last"` or `"channels_first"`, specifying
            the data format for the input image. If `None`, defaults to
            `"channels_last"`.
        dtype: The dtype of the layer weights. Defaults to `None`.
        **kwargs: Additional keyword arguments to be passed to the parent
            `Backbone` class.
    """

    def __init__(
        self,
        image_shape,
        patch_size,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout_rate=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-6,
        use_mha_bias=True,
        use_mlp_bias=True,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        data_format = standardize_data_format(data_format)

        h_axis, w_axis, channels_axis = (
            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
        )
        # Check that the input image is well specified.
        if image_shape[h_axis] is None or image_shape[w_axis] is None:
            raise ValueError(
                f"Image shape must have defined height and width. Found `None` "
                f"at index {h_axis} (height) or {w_axis} (width). "
                f"Image shape: {image_shape}"
            )
        if image_shape[h_axis] != image_shape[w_axis]:
            raise ValueError(
                f"Image height and width must be equal. Found height: "
                f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
                f"indices {h_axis} and {w_axis} respectively. Image shape: "
                f"{image_shape}"
            )

        num_channels = image_shape[channels_axis]

        # === Functional Model ===
        inputs = keras.layers.Input(shape=image_shape)

        x = ViTPatchingAndEmbedding(
            image_size=image_shape[h_axis],
            patch_size=patch_size,
            hidden_dim=hidden_dim,
            num_channels=num_channels,
            data_format=data_format,
            dtype=dtype,
            name="vit_patching_and_embedding",
        )(inputs)

        output = ViTEncoder(
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
            attention_dropout=attention_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            use_mha_bias=use_mha_bias,
            use_mlp_bias=use_mlp_bias,
            dtype=dtype,
            name="vit_encoder",
        )(x)

        super().__init__(
            inputs=inputs,
            outputs=output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.image_shape = image_shape
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_mha_bias = use_mha_bias
        self.use_mlp_bias = use_mlp_bias
        self.data_format = data_format

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_shape": self.image_shape,
                "patch_size": self.patch_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "mlp_dim": self.mlp_dim,
                "dropout_rate": self.dropout_rate,
                "attention_dropout": self.attention_dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "use_mha_bias": self.use_mha_bias,
                "use_mlp_bias": self.use_mlp_bias,
            }
        )
        return config
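Since `ViTBackbone` is exported as `keras_hub.models.ViTBackbone`, it can also be constructed directly from hyperparameters. A minimal sketch follows; the values are the standard ViT-B/16 configuration chosen for illustration (randomly initialized, no preset weights), and the class-token shape math matches the test case below, where a 50-token sequence comes from 49 patches plus one class token:

import numpy as np
import keras_hub

# ViT-B/16-style configuration (illustrative values, not a shipped preset).
backbone = keras_hub.models.ViTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
)

images = np.ones((1, 224, 224, 3), dtype="float32")
features = backbone(images)
# (224 / 16) ** 2 = 196 patches, plus one class token -> 197 tokens,
# so features.shape == (1, 197, 768).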
New file — ViT backbone tests (@@ -0,0 +1,37 @@):
import pytest
from keras import ops

from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.tests.test_case import TestCase


class ViTBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "image_shape": (28, 28, 3),
            "patch_size": 4,
            "num_layers": 3,
            "hidden_dim": 48,
            "num_heads": 6,
            "mlp_dim": 48 * 4,
            "use_mha_bias": True,
        }
        self.input_size = 28
        self.input_data = ops.ones((2, self.input_size, self.input_size, 3))

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ViTBackbone,
            init_kwargs={**self.init_kwargs},
            input_data=self.input_data,
            expected_output_shape=(2, 50, 48),
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ViTBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
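The expected output shape in test_backbone_basics follows directly from the test configuration: a 28x28 input with patch_size=4 gives (28/4)^2 = 49 patches, one prepended class token brings the sequence length to 50, and hidden_dim=48 sets the feature width, so the batch of 2 images yields (2, 50, 48). A quick arithmetic check, under the same class-token assumption:

# Sequence-length sanity check for the test configuration above.
image_size, patch_size, hidden_dim = 28, 4, 48
num_patches = (image_size // patch_size) ** 2  # 49
seq_len = num_patches + 1  # 50, including the class token
print((2, seq_len, hidden_dim))  # (2, 50, 48)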