[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989


Merged
merged 37 commits into from
Dec 12, 2024
Conversation

@sineeli sineeli commented Nov 21, 2024

This PR introduces a Vision Transformer (ViT) implementation with the following components:

  1. Backbone
  2. Preprocessor
  3. Image classifier
  4. Weights transfer script

@sineeli sineeli added the kokoro:force-run Runs Tests on GPU label Nov 23, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Nov 23, 2024
@sineeli sineeli added the kokoro:force-run Runs Tests on GPU label Nov 27, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Nov 27, 2024
Member

@mattdangerw mattdangerw left a comment

This looks great! Very nice work. Just a couple comments.

@@ -137,7 +139,10 @@ def __init__(
# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
if pooling == "token": # used for Vision Transformer(ViT)
Member

"token" feels like a bit a weird name here, especially when compared to "avg" or "max". Maybe "first"?

Member

Actually, wouldn't this also break for other classifier types? I think this "token" pooling would fail to actually pool over a 2D output from most backbones, and similarly global avg 2D pooling would fail to pool correctly for a ViT backbone, right (since it's a 1D sequence after patching)? Instead we should subclass here and not let pooling be configurable for ViT. See https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/vgg/vgg_image_classifier.py as an example of this.

Collaborator Author

Oh yes, I was thinking earlier of subclassing and writing a completely new one. Thanks for pointing it out; I will make the required changes.
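The shape mismatch being discussed can be sketched in plain NumPy (a hypothetical illustration, not the keras-hub implementation): CLS-token pooling assumes a rank-3 token sequence, while a convolutional backbone emits a rank-4 feature map, so neither pooling mode transfers cleanly across backbone types.

```python
import numpy as np

# Hypothetical sketch of the two pooling modes under discussion.
# A ViT backbone emits a token sequence (batch, num_tokens, hidden_dim);
# a conv backbone emits a feature map (batch, height, width, channels).

def pool_token(x):
    # "token" (CLS) pooling: take the first token of the sequence.
    # Only meaningful for rank-3 sequence outputs.
    return x[:, 0, :]

def pool_gap(x):
    # Global average pooling over every axis except batch and channels.
    return x.mean(axis=tuple(range(1, x.ndim - 1)))

vit_out = np.ones((2, 197, 768))   # 196 patches + 1 CLS token
cnn_out = np.ones((2, 7, 7, 512))  # conv feature map

print(pool_token(vit_out).shape)  # (2, 768)
print(pool_gap(cnn_out).shape)    # (2, 512)
# Applying pool_token to cnn_out instead slices x[:, 0, :] and returns a
# (2, 7, 512) map rather than a pooled vector -- hence the suggestion to
# subclass rather than make pooling configurable for ViT.
```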

Collaborator Author

@sineeli sineeli Dec 10, 2024

@mattdangerw

Also, from Hugging Face I observed that there is one more dense layer when the model is not used for image classification, which they call the pooler layer: it is just a dense layer (projecting to the same hidden dimension) followed by a tanh activation.

Should we include this? If we are considering only ImageClassification, this layer wouldn't be present.

ViTModel: https://github.com/huggingface/transformers/blob/91b8ab18b778ae9e2f8191866e018cd1dc7097be/src/transformers/models/vit/modeling_vit.py#L576

Image Classification: https://github.com/huggingface/transformers/blob/91b8ab18b778ae9e2f8191866e018cd1dc7097be/src/transformers/models/vit/modeling_vit.py#L823C37-L823C54

Any thoughts?
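For reference, the Hugging Face pooler described above can be sketched as a single dense projection plus tanh applied to the CLS token. This is a NumPy stand-in with placeholder weights, not the HF or keras-hub code:

```python
import numpy as np

# NumPy sketch of the Hugging Face-style ViT pooler described above: one
# dense layer projecting hidden_dim -> hidden_dim, followed by tanh,
# applied to the first (CLS) token. Weights are random placeholders.

rng = np.random.default_rng(0)
hidden_dim = 8
w = rng.standard_normal((hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)

def vit_pooler(sequence_output):
    cls_token = sequence_output[:, 0, :]   # (batch, hidden_dim)
    return np.tanh(cls_token @ w + b)      # same width, squashed to (-1, 1)

x = rng.standard_normal((2, 5, hidden_dim))  # (batch, tokens, hidden)
pooled = vit_pooler(x)
print(pooled.shape)  # (2, 8)
```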


@sineeli
Collaborator Author

sineeli commented Dec 4, 2024

Member

@mattdangerw mattdangerw left a comment

Looking good, but let's fix the broken pooling configurations.

@sineeli sineeli added kokoro:force-run Runs Tests on GPU and removed kokoro:force-run Runs Tests on GPU labels Dec 9, 2024
@sineeli
Collaborator Author

sineeli commented Dec 9, 2024

Model ID                                   Weights Transfer Status
Base
hf://google/vit-base-patch16-224           successful backbone + head weights transfer
hf://google/vit-base-patch16-224-in21k     layer name mismatch from model.safetensors
hf://google/vit-base-patch16-384           successful backbone + head weights transfer
hf://google/vit-base-patch32-224-in21k     layer name mismatch from model.safetensors
hf://google/vit-base-patch32-384           successful backbone + head weights transfer
Large
hf://google/vit-large-patch16-224          no model.safetensors
hf://google/vit-large-patch16-224-in21k    layer name mismatch from model.safetensors
hf://google/vit-large-patch16-384          no model.safetensors
hf://google/vit-large-patch32-224-in21k    no model.safetensors
hf://google/vit-large-patch32-384          no model.safetensors
Huge
hf://google/vit-huge-patch14-224-in21k     layer name mismatch from model.safetensors

@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Dec 10, 2024
@sineeli sineeli added the kokoro:force-run Runs Tests on GPU label Dec 12, 2024
Member

@mattdangerw mattdangerw left a comment

Looks good! Just some comments on the classifier API

`"token_unpooled"`: Outputs the tokens directly from `ViTBackbone`
representation_size: Optional dimensionality of the intermediate
representation layer before the final classification layer.
If `None`, the output of the transformer is directly used."
Member

trailing quote mark?

Collaborator Author

Done

self.preprocessor = preprocessor

if representation_size is not None:
self.representation_layer = keras.layers.Dense(
Member

Maybe call this `intermediate_dim`? It fits with other places where we have an arg for the middle size of a two-layer MLP.

Collaborator Author

Done
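The renamed argument can be illustrated with a minimal sketch of the classifier head under discussion: an optional intermediate dense layer between the pooled backbone output and the class logits. Names and weights here are illustrative NumPy placeholders, not the keras-hub code:

```python
import numpy as np

# Sketch of the classifier head being discussed: an optional
# intermediate dense layer (the `intermediate_dim` naming suggested
# above) between the pooled output and the final class logits.

rng = np.random.default_rng(1)
hidden_dim, intermediate_dim, num_classes = 16, 8, 4

w_mid = rng.standard_normal((hidden_dim, intermediate_dim))
w_out = rng.standard_normal((intermediate_dim, num_classes))

def head(pooled):
    x = np.tanh(pooled @ w_mid)  # optional "representation" layer (dense + tanh)
    return x @ w_out             # final logits: (batch, num_classes)

pooled = rng.standard_normal((2, hidden_dim))
print(head(pooled).shape)  # (2, 4)
```

When `intermediate_dim` is None, the pooled output would feed the logits layer directly, matching the "If `None`, the output of the transformer is directly used" behavior in the docstring above.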

elif pooling == "gap":
ndim = len(ops.shape(x))
x = ops.mean(x, axis=list(range(1, ndim - 1))) # (1,) or (1,2)
elif pooling == "token_unpooled":
Member

Will this change the output shape? Output of an image classifier should be (batch_size, num_classes). How is this expected to be used?

Collaborator Author

Oh yeah, this part is not required; the backbone will serve the purpose if users want it for some task other than image classification.

It is used in the JAX code because they have a single network. Thanks Matt!

Collaborator Author

Done

@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Dec 12, 2024
@mattdangerw
Member

Looks good to me! Will merge once green.

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Dec 12, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Dec 12, 2024
@mattdangerw mattdangerw merged commit 15564ca into keras-team:master Dec 12, 2024
10 checks passed
4 participants