Add Flax Dinov2 #31960
Conversation
Hi @MHRDYN7, thanks for working on this conversion! The interpolate logic is tricky. If loading the weights directly at the moment means this model passes, then that's a good guide that the conversion is OK. We might have to do something where we remove this before merge and skip the equivalence tests. In the meantime, the first thing to do is get the other tests passing. Some of the failing tests are unrelated and have fixes upstream. Could you rebase on main to include these? For the quality checks, running `make fixup` and pushing the changes should resolve them.
Hi @amyeroberts, I did try `make fixup` and I'm not really sure why the two tests are still failing. Moreover, what should be done for skipping the equivalence test? Should I just remove the directly loaded tensors after interpolation and change the "expected_slice" tensor in the integration tests accordingly to make them pass?
@MHRDYN7 For the quality checks, you'll need to run the quality scripts locally and push the resulting changes.

For skipping the equivalence tests, in general we would add a skip decorator with a reason to the relevant tests in the model's test file.

For your proposal re weights, this might be a good option as we'd still be checking the rest of the model. When you say "remove", I'm guessing you mean from the respective state dicts?
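For reference, skipping an equivalence test usually looks roughly like the following (a hedged sketch; the test class, test name, and reason below are placeholders, not taken from this PR):

```python
# Hedged illustration only: how a cross-framework equivalence test is typically
# skipped in a test class. Names and reason are hypothetical.
import unittest


class FlaxDinov2ModelTest(unittest.TestCase):  # hypothetical test class
    @unittest.skip(reason="Bicubic interpolation differs slightly between PyTorch and JAX")
    def test_equivalence_pt_to_flax(self):
        pass
```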
@amyeroberts thank you. I have tried to solve all the issues. To summarize, I have finally decided to keep the jax.image.scale_and_translate based interpolation rather than the directly loaded interpolated weights.
@MHRDYN7 Great!

> Is there a plan for this to happen in the future?

This shouldn't be necessary: all of the model frameworks (TF, PyTorch, Flax) should be able to load the safetensors file.
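As a quick usage sketch of that point (assuming the FlaxDinov2ForImageClassification class added by this PR and the checkpoint name from the doc example further down):

```python
# Flax should be able to load the shared safetensors checkpoint directly,
# without a separate .msgpack file being pushed to the Hub.
from transformers import FlaxDinov2ForImageClassification

model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
```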
The issue from the jax repo mentioned in my first comment suggests that they did come up with a fix, but no steps were taken and there are no PRs related to the issue. I might just open a PR there if I can; it shouldn't be hard to solve.

It's good to hear that. Indeed, the models don't necessarily need the .msgpack weights. I had seen flax weights on the hub for many models and thought it was a convention to add them.
Thanks for adding this!
Mostly just some nits. Overall LGTM. @sanchit-gandhi could you give a quick once-over to confirm the flax is OK?
new_height_ratio = jnp.float32(height / math.sqrt(num_positions))  # ? 16/37
new_width_ratio = jnp.float32(width / math.sqrt(num_positions))  # ? 16/37

# patch_pos_embed = jax.image.resize(patch_pos_embed, shape=(hidden_states.shape[0], dim, height, width), method='bicubic', antialias=False)
Why commented out?
It seems that I mistakenly wrote in my last comment that I used jax.image.resize, whilst I actually used jax.image.scale_and_translate. Both functions ultimately call the same helper function internally, and therefore either of them could be used for interpolating the tensor. The reason why scale_and_translate() is the better fit is that it allows us to set the scale argument explicitly (which is key according to the original Dinov2 repo), while resize() determines the scale on its own. I'll remove the commented-out line of code.
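For context, a minimal sketch of how the two JAX options compare (not the PR's exact code; the grid sizes, shapes, and variable names below are illustrative):

```python
import jax
import jax.numpy as jnp

# Hypothetical patch position embeddings: (batch, num_positions, dim).
patch_pos_embed = jnp.ones((1, 37 * 37, 768))
dim = patch_pos_embed.shape[-1]
num_positions = patch_pos_embed.shape[1]
src = int(num_positions ** 0.5)       # source grid size (37 here)
height = width = 16                   # target grid size, illustrative

# Reshape to (batch, src, src, dim) so the two spatial axes can be resized.
x = patch_pos_embed.reshape(1, src, src, dim)

# Option 1: jax.image.resize infers the scale from the input/output shapes.
resized = jax.image.resize(
    x, shape=(1, height, width, dim), method="bicubic", antialias=False
)

# Option 2: jax.image.scale_and_translate lets us pass the scale explicitly,
# which is what makes it the closer match to the original Dinov2 logic.
scale = jnp.array([height / src, width / src], dtype=jnp.float32)
scaled = jax.image.scale_and_translate(
    x,
    shape=(1, height, width, dim),
    spatial_dims=(1, 2),
    scale=scale,
    translation=jnp.zeros(2, dtype=jnp.float32),
    method="bicubic",
    antialias=False,
)
```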
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@amyeroberts thanks a lot for the review. All the tests are passing again.
The PR generally looks in good shape! Well done on handling all of the weight initialisations carefully @MHRDYN7 and porting the new functions over to Flax.
The main request from my review is using `# Copied from` statements as much as possible. There are many modules / methods that are copied 1-for-1 from existing models in the library. Here, prepending them with a `# Copied from` helps:
- Keep code sync'd across models
- The reviewer pinpoint which parts of the code to focus on!

Regarding your PR description: I didn't fully understand what the issue was with the position embedding weights - you've defined them as a standard `self.param`, and the keys look to match those from PyTorch? Let me know if I'm missing something here!
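For reference, defining position embeddings as a standard parameter in flax.linen looks roughly like this (a hedged sketch with an assumed module name and shapes, not the PR's actual module):

```python
import flax.linen as nn


class Dinov2PositionEmbeddingsSketch(nn.Module):  # hypothetical module name
    num_patches: int
    hidden_size: int

    def setup(self):
        # The parameter key ("position_embeddings") is what has to line up with
        # the corresponding PyTorch state dict entry when converting weights.
        self.position_embeddings = self.param(
            "position_embeddings",
            nn.initializers.zeros,
            (1, self.num_patches + 1, self.hidden_size),  # +1 for the CLS token
        )

    def __call__(self, embeddings):
        # In the real model the raw parameter is interpolated to the input's
        # patch grid before being added; that step is omitted here.
        return embeddings + self.position_embeddings
```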
)


class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel):
I would copy this from Beit
# Copied from transformers.models.beit.modeling_flax_beit.FlaxBeitPreTrainedModel with Beit->Dinov2, beit->dinov2
class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel):
# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

params_rng, dropout_rng = jax.random.split(rng)
We're missing the rng for the droppath - copying from Beit is going to fix this
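Roughly, the Beit-style pattern splits off an extra rng for the drop-path (stochastic depth) collection. A hedged sketch of that shape of init (names follow the snippet above; the exact Beit code may differ slightly):

```python
import jax
import jax.numpy as jnp


def init_weights_sketch(module, rng, input_shape, dtype=jnp.float32):
    # init input tensors
    pixel_values = jnp.zeros(input_shape, dtype=dtype)

    # Split a dedicated rng for stochastic depth ("droppath") in addition to
    # the params and dropout rngs.
    params_rng, dropout_rng = jax.random.split(rng)
    dropout_rng, droppath_rng = jax.random.split(dropout_rng)
    rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

    # `module` stands in for the underlying flax.linen module of the model.
    return module.init(rngs, pixel_values)["params"]
```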
@@ -0,0 +1,259 @@
# coding=utf-8
It'd be super helpful to add `# Copied from` statements in the tests as well!
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")

>>> inputs = image_processor(images=image, return_tensors="np")
Fine to do "np", since we convert "np" arrays to "jnp" arrays before calling the Flax module! (In fact, doing "np" is preferable here, since "jnp" arrays are automatically created on the accelerator device, whereas "np" is always on cpu -> creating your input on cpu and only moving it to accelerator when required is better for async dispatch)
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@sanchit-gandhi thanks a lot for the careful review. A summary of the updates:

The position embedding weights can be loaded perfectly. However, these weights are later modified according to the number of patches using F.interpolate (with bicubic mode) in torch. We can replicate this behavior with jax.image.scale_and_translate (or with jax.image.resize), but this function differs slightly from the torch interpolate in the bicubic mode only, resulting in slightly different output hidden_states. Please note:
Thanks for the detailed explanation and iterating with us @MHRDYN7! @sanchit-gandhi is off at the moment, but I can see you've addressed his comments, so I think we're OK to merge without his second review. The final step is running the slow tests for the model before merge. Could you push an empty commit with the message that triggers the slow test run?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@amyeroberts, I've pushed the required commit. Now I guess it requires your approval to run the slow tests.
@amyeroberts slow tests passed! Ready to be merged.
Great piece of work - thanks for adding!
This PR adds the Flax implementation of Dinov2, which seems to have been due since #25579.
All the components of the PyTorch Dinov2 model can be converted to Flax except "interpolate_pos_encoding", which uses torch.nn.functional.interpolate. The closest jax function for replicating this is jax.image.scale_and_translate; however, there seems to be a slight difference between these functions in the bicubic mode (https://github.com/google/jax/issues/15768).
In Dinov2, the pretrained position encoding weights are for an image size of 512, but we load images of size 224 into the model. The interpolate function resizes the position encoding to match the patch grid of the input images. The ViT model does have this interpolate function, but it's not present in the FlaxViT implementation, since there the config and input image sizes are the same.
For now, I have directly loaded the pos_encoding weights from the pt model into flax right after interpolation (saved in a safetensors file). This passes all the tests (including the two new integration tests added on top of the FlaxViT tests). Clearly, this brute-force approach of loading the original interpolated pos_encodings will not work in the end, but otherwise the slight deviations from jax scale_and_translate will fail the tests. @amyeroberts @sanchit-gandhi
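To make the mismatch concrete, a small hypothetical comparison (not part of the PR; the array sizes are illustrative) of the two bicubic resizes on the same data:

```python
import numpy as np
import torch
import torch.nn.functional as F
import jax
import jax.numpy as jnp

# A fake 37x37 grid of position embeddings with 8 channels, resized to 16x16.
x = np.random.default_rng(0).standard_normal((1, 8, 37, 37)).astype(np.float32)

# PyTorch path, as used by interpolate_pos_encoding.
pt_out = F.interpolate(
    torch.from_numpy(x), size=(16, 16), mode="bicubic", align_corners=False
).numpy()

# JAX path with the same target shape and method.
jax_out = np.asarray(
    jax.image.resize(jnp.asarray(x), shape=(1, 8, 16, 16), method="bicubic", antialias=False)
)

# The outputs are expected to agree closely but not exactly in bicubic mode.
print(np.abs(pt_out - jax_out).max())
```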
Other Remaining Tasks:
- Add the flax weights in .msgpack files to the hub
- Test the SwiGLUFFN dense layer for ViT giant