Add DeiT Model #2203

Open

Sohaib-Ahmed21 wants to merge 18 commits into master

Conversation

@Sohaib-Ahmed21 (Author) commented Apr 3, 2025

Add the ViT-based DeiT (Data-efficient Image Transformers) model to keras-hub, along with its backbone, layers, tests, and checkpoint conversion.
Paper: https://arxiv.org/pdf/2012.12877
Model card: https://huggingface.co/facebook/deit-base-distilled-patch16-384
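A hypothetical usage sketch (the class name, preset handle, and feature shape here are assumptions based on keras-hub conventions; the final API is whatever ships with this PR):

import numpy as np
import keras_hub

# Load the DeiT backbone from a preset (name is illustrative).
backbone = keras_hub.models.DeiTBackbone.from_preset(
    "deit_base_distilled_patch16_384_imagenet"
)
images = np.random.uniform(size=(1, 384, 384, 3)).astype("float32")
features = backbone(images)  # ViT-style encoder features for the image batch.
print(features.shape)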

google-cla bot commented Apr 3, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Sohaib-Ahmed21 (Author) commented Apr 18, 2025

@divyashreepathihalli @JyotinderSingh , kindly approve the workflows.

@sachinprasadhs (Collaborator) commented

@Sohaib-Ahmed21, thanks for the PR! Could you please add the checkpoint_conversion and link a Colab notebook here verifying the numerics and parameter matching, and demonstrating an end-to-end usage example of the model?

@Sohaib-Ahmed21 (Author) commented

> @Sohaib-Ahmed21, thanks for the PR! Could you please add the checkpoint_conversion and link a Colab notebook here verifying the numerics and parameter matching, and demonstrating an end-to-end usage example of the model?

Yeah sure, I'll do that and update the PR, thanks!

@Sohaib-Ahmed21 (Author) commented

Added the checkpoint_conversion; I'll share the parameter/numerics verification and end-to-end demo notebooks soon.

@Sohaib-Ahmed21 (Author) commented

I’m sharing the following resources related to this PR:

  • DeiT Checkpoint Conversion and Numerics Verification Demo (across multiple backends): Notebook Link
  • DeiT End-to-End Demo (zero-shot and finetuning): Notebook Link
  • Here are the converted DeiT presets from Hugging Face checkpoints for reference.

@Sohaib-Ahmed21 (Author) commented

This PR is ready for review. Kindly approve the workflows and review the PR, thanks!

@sachinprasadhs (Collaborator) left a comment

Thanks for the demo notebook; I added a few more comments.

Also, in the Colab, can you assert that the parameter counts match between the Keras and HF models? For example, the tiny variant has "params": 5524800; can you show the output where it matches the HF parameter count?
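A minimal sketch of such an assertion (the Hugging Face side uses the real transformers.DeiTModel; keras_backbone is an illustrative name standing in for the converted keras-hub model):

from transformers import DeiTModel

# Load the reference HF checkpoint (the model card linked in the PR description).
hf_model = DeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-384")
hf_params = sum(p.numel() for p in hf_model.parameters())

# keras_backbone is assumed to be the converted keras-hub DeiT backbone.
keras_params = keras_backbone.count_params()
assert keras_params == hf_params, (keras_params, hf_params)
print(f"Parameter counts match: {keras_params}")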

@Sohaib-Ahmed21 (Author) commented

Thanks for the detailed review. I'll address the comments soon.

> can you show the output where it matches the HF parameter count?

Yes, I'll show that.

@Sohaib-Ahmed21 (Author) commented

I've addressed all review comments; kindly review the updated PR. The notebook has also been updated to include parameter verification.

@sachinprasadhs (Collaborator) commented

Awesome, this looks great! Thank you.

@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label May 2, 2025
@Sohaib-Ahmed21 (Author) commented

Is the failing test related to the PR? Kindly confirm and re-run the tests if required.

@divyashreepathihalli added the kokoro:force-run (Runs Tests on GPU) label May 5, 2025
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label May 5, 2025
@mattdangerw (Member) commented

@Sohaib-Ahmed21 no, the jax-gpu segfault popped up recently; it's probably related to our test environment (we haven't tracked it down yet). You can ignore it.

@mattdangerw (Member) left a comment

A few misc comments, but this is looking good!


# Metadata for loading pretrained model weights.
backbone_presets = {
    "deit-base-distilled-patch16-384_imagenet": {
@mattdangerw (Member) commented

I think we use only underscores in our preset names, no mixing dashes and underscores like this.

@mattdangerw (Member) commented

This should be true for the preset names here and on Kaggle; we want consistency.
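Following that convention, the preset key above would become (illustrative):

"deit_base_distilled_patch16_384_imagenet": {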

@@ -73,7 +76,10 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             cls, load_weights, load_task_weights, **kwargs
         )
         # Support loading the classification head for classifier models.
-        if architecture == "ViTForImageClassification":
+        if (
@mattdangerw (Member) commented

Maybe we could just check if ForImageClassification is in the name here?
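A sketch of that check, assuming architecture holds the Hugging Face architecture string (e.g. "DeiTForImageClassificationWithTeacher"):

if "ForImageClassification" in architecture:
    # Covers ViTForImageClassification, DeiTForImageClassification,
    # DeiTForImageClassificationWithTeacher, etc.
    ...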
