[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989
Conversation
This looks great! Very nice work. Just a couple comments.
@@ -137,7 +139,10 @@ def __init__(
# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
if pooling == "token":  # used for Vision Transformer (ViT)
"token" feels like a bit a weird name here, especially when compared to "avg"
or "max"
. Maybe "first"
?
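For reference, a tiny sketch (assuming a standard `(batch, seq, dim)` ViT output and the Keras 3 `ops` API) of what each pooling name would select:

```python
from keras import ops

# Hypothetical (batch, tokens, hidden_dim) ViT output.
x = ops.ones((2, 197, 768))

first = x[:, 0]            # "token" / "first": the class token, (2, 768)
avg = ops.mean(x, axis=1)  # "avg": mean over the sequence, (2, 768)
mx = ops.max(x, axis=1)    # "max": max over the sequence, (2, 768)
```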
Actually wouldn't this also break for other classifier types? I think this "token" pooling would fail to actually pool over a 2d output from most backbones, and similarly global avg 2d pooling would fail to pool correctly for a ViT backbone, right (since it's a 1d sequence after patching)? Instead we should subclass here, and not let pooling be configurable for ViT. See https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/vgg/vgg_image_classifier.py as an example of this
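A minimal sketch of that subclassing approach, modeled on the linked `vgg_image_classifier.py` pattern; the class name and constructor arguments here are illustrative, not the final PR API:

```python
import keras
from keras import layers

class ViTImageClassifier(keras.Model):
    """Sketch: ViT-specific classifier with fixed (non-configurable) pooling."""

    def __init__(self, backbone, num_classes, **kwargs):
        # Build the functional graph up front, as the VGG classifier does.
        inputs = backbone.input
        x = backbone(inputs)  # (batch, num_patches + 1, hidden_dim)
        x = x[:, 0]           # always pool via the class token for ViT
        outputs = layers.Dense(num_classes)(x)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
        self.backbone = backbone
        self.num_classes = num_classes
```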
Oh yes, I was thinking earlier to subclass and write a totally new one. Thanks for pointing that out; I will make the required changes.
Also, from Hugging Face I observed that there is one more dense layer when the model is not used for image classification, which they call the pooler layer; it is just a dense layer (projecting to the same hidden dimension) with a tanh activation.
Should we include this? If we only consider image classification, this layer wouldn't be present.
Image Classification: https://github.com/huggingface/transformers/blob/91b8ab18b778ae9e2f8191866e018cd1dc7097be/src/transformers/models/vit/modeling_vit.py#L823C37-L823C54
Any thoughts?
The original JAX code calls it representation size: https://github.com/google-research/vision_transformer/blob/c6de1e5378c9831a8477feb30994971bdc409e46/vit_jax/models_vit.py#L296C13-L296C32
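A sketch of that optional layer, assuming the HF pooler / JAX `representation_size` behavior (a same-width dense projection plus `tanh`, applied before the classification head); all names and sizes here are illustrative:

```python
import keras

hidden_dim = 768
representation_size = 768  # None would skip the extra layer entirely
num_classes = 1000

cls_token = keras.Input(shape=(hidden_dim,))
x = cls_token
if representation_size is not None:
    # The "pooler" / representation layer: Dense + tanh.
    x = keras.layers.Dense(representation_size, activation="tanh")(x)
logits = keras.layers.Dense(num_classes)(x)
head = keras.Model(cls_token, logits)
```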
Looking good, but let's fix the broken pooling configurations.
Looks good! Just some comments on the classifier API
`"token_unpooled"`: Ouputs directly tokens from `ViTBackbone` | ||
representation_size: Optional dimensionality of the intermediate | ||
representation layer before the final classification layer. | ||
If `None`, the output of the transformer is directly used." |
trailing quote mark?
Done
self.preprocessor = preprocessor

if representation_size is not None:
    self.representation_layer = keras.layers.Dense(
maybe call this `intermediate_dim`? Fits with other places where we have an arg for the middle size of a two-layer MLP.
Done
elif pooling == "gap":
    ndim = len(ops.shape(x))
    x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1, 2)
elif pooling == "token_unpooled":
Will this change the output shape? Output of an image classifier should be (batch_size, num_classes). How is this expected to be used?
Oh yeah, this part is not required; the backbone will serve the purpose if users want to use it for some other task rather than image classification.
It is used in the JAX code since they have a single network. Thanks, Matt!
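To illustrate the point (with a stand-in backbone that just mimics a ViT's token-sequence output shape; the real `ViTBackbone` would be used in practice), unpooled features come straight from the backbone, while the classifier keeps its `(batch_size, num_classes)` output:

```python
import keras

hidden_dim = 768

# Stand-in for ViTBackbone: patchify 224x224 images into a token sequence.
inputs = keras.Input(shape=(224, 224, 3))
x = keras.layers.Conv2D(hidden_dim, 16, strides=16)(inputs)  # 14x14 patches
x = keras.layers.Reshape((-1, hidden_dim))(x)                # (batch, 196, 768)
backbone = keras.Model(inputs, x)

tokens = backbone(keras.random.uniform((2, 224, 224, 3)))
print(tokens.shape)  # (2, 196, 768): sequence features for non-classification tasks
```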
Done
looks good to me! will merge once green
This PR introduces a Vision Transformer (ViT) implementation