Avoiding graph break by changing the way we infer dtype in vae.decoder #12512

DN6 merged 6 commits into huggingface:main
Conversation
@DN6 WDYT?

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

I made sure all autoencoder tests are passing locally. I would be very thankful if you could take a look @DN6
```diff
  sample = self.conv_in(sample)

- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ upscale_dtype = self.up_blocks[0].resnets[0].norm1.weight.dtype
```
I think the current failing tests in the CI are due to the fact that not every decoder block has a `norm1` with a weight. Hence the generator was used here to avoid such cases.
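To illustrate the point above: a direct lookup like `up_blocks[0].resnets[0].norm1.weight` fails on block layouts that lack those submodules, which is why the generator was used. A minimal sketch (the toy `up_blocks` below is hypothetical, not the diffusers decoder):

```python
import torch
import torch.nn as nn

# Hypothetical up-block stack: the first block has no `resnets`/`norm1`,
# so a hard-coded attribute path would raise AttributeError on it.
up_blocks = nn.ModuleList([
    nn.Upsample(scale_factor=2),   # parameter-free block
    nn.Conv2d(4, 4, 3, padding=1), # first block that actually has weights
])

# The generator-based lookup works for any block layout, because it just
# grabs the first parameter it finds. The downside (per this PR) is that
# under torch.compile it materializes a lazy tensor and breaks the graph.
upscale_dtype = next(iter(up_blocks.parameters())).dtype
print(upscale_dtype)  # torch.float32
```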
@ppadjinTT I noticed you initially used `self.conv_out.weight` here? What was the issue you ran into with that?
Okay, I will change that too, thanks! I initially changed `self.conv_out.weight` because there are some tests that check what happens when `conv_out` and the upscale blocks have different dtypes.
Could you point me to those tests? Setting it to `conv_out` seems more robust.
Yup, these are the tests: `pytest -svvv tests/models/autoencoders/test_models_autoencoder_kl.py`

This is one of the tests from this test set that fails: `tests/models/autoencoders/test_models_autoencoder_kl.py::AutoencoderKLTests::test_layerwise_casting_inference`
I added better logic for inferring the dtype, to handle the case where the direct lookup doesn't work.
Hmm, I think we can remove `upscale_dtype` entirely here. I think all tests should still pass without it.
Okay, let's try that; I'm pushing the change.
Any chance you can take a look? Thanks for your effort @DN6
Thanks @ppadjinTT 👍🏽
What does this PR do?
This PR addresses the problem discussed in #12501, where using `upscale_dtype = next(iter(self.up_blocks.parameters())).dtype` to infer the dtype in the forward pass of the `vae.decoder` causes a graph break when compiling the model with `torch.compile`. The issue is that `next(iter(...))` forces the lazy tensors in the initial compiled model pass to materialize, resulting in a graph break, which decreases performance. This PR proposes a simple fix: infer the `dtype` with a direct attribute access, e.g. `upscale_dtype = self.up_blocks[0].resnets[0].norm1.weight.dtype`.

Fixes #12501
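The shape of the fix can be sketched with a toy decoder. This is an illustration only, under the assumption of a simplified module layout (`TinyDecoder`, `conv_in`, `conv_out` are stand-ins mirroring the attribute names discussed in this PR, not the diffusers implementation):

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Toy stand-in for vae.decoder with the same dtype-inference pattern."""

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(4, 8, 3, padding=1)
        self.conv_out = nn.Conv2d(8, 3, 3, padding=1)

    def forward(self, sample):
        sample = self.conv_in(sample)
        # Direct attribute access reads the dtype without iterating over
        # parameters(), so torch.compile does not have to materialize a
        # lazy tensor mid-trace (the source of the graph break).
        upscale_dtype = self.conv_out.weight.dtype
        return self.conv_out(sample.to(upscale_dtype))

decoder = TinyDecoder()
out = decoder(torch.randn(1, 4, 8, 8))
print(out.shape)  # torch.Size([1, 3, 8, 8])
```

In the real decoder the same idea applies: any fixed submodule whose weight is guaranteed to exist (the conversation above settles on `conv_out`, and ultimately on dropping `upscale_dtype` altogether) avoids the `next(iter(...))` pattern.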
Who can review?
@sayakpaul