
Conversation

@gante (Contributor) commented Apr 18, 2025

What does this PR do?

kernels and torch.compile are not yet compatible with each other. Although we can skip custom kernels when the package is not installed, surfacing an error message is also not feasible -- we can't throw exceptions at compile time.

This PR hijacks the kernels decorator to add a compile-friendly path: until kernels supports torch.compile, let's use the original forward.
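
For context, a rough sketch of the kind of hijack described above. It is illustrative only, not the code added in this PR: the decorator name and the calling convention of kernel_forward are made up, and is_torchdynamo_compiling stands in for the transformers helper of the same name.

import torch

def is_torchdynamo_compiling():
    # simplified stand-in for the transformers utility
    return torch.compiler.is_compiling()

def use_kernel_with_compile_fallback(kernel_forward):
    # class decorator: dispatch to the custom kernel in eager mode,
    # and to the class's original forward while torch.compile is tracing
    def decorator(cls):
        original_forward = cls.forward

        def forward(self, *args, **kwargs):
            if is_torchdynamo_compiling():
                return original_forward(self, *args, **kwargs)
            return kernel_forward(self, *args, **kwargs)

        cls.forward = forward
        return cls

    return decorator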

@github-actions github-actions bot marked this pull request as draft April 18, 2025 09:18
@github-actions (bot):

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@gante gante marked this pull request as ready for review April 18, 2025 09:20
@gante (Contributor, Author) commented Apr 18, 2025

@qubvel I'm not sure how to implement the manual escape path (as discussed on Slack and here) 🤔

From the decorator's perspective, we only have access to cls. Some classes where the kernel is applied don't offer any way to set extra flags at model definition time (e.g.), so we can't do something like model.disable_custom_kernels = False with the decorator approach -- not unless we add a lot of extra code.

Using DISABLE_KERNEL_MAPPING=1 should work, though.

Any suggestions or ideas?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante requested review from Cyrilvallez and qubvel April 18, 2025 09:55
Comment on lines 80 to 84
def forward_with_compile_path(*forward_args, **forward_kwargs):
    if is_torchdynamo_compiling():
        return original_forward(*forward_args, **forward_kwargs)
    else:
        return kernel_forward(*forward_args, **forward_kwargs)

@qubvel (Contributor):

Maybe not super clean, but we can have an extra kwarg, something like that:

Suggested change:

def forward_with_compile_path(*forward_args, **forward_kwargs):
    use_kernel = forward_kwargs.pop("use_kernel", True)
    if is_torchdynamo_compiling() or not use_kernel:
        return original_forward(*forward_args, **forward_kwargs)
    else:
        return kernel_forward(*forward_args, **forward_kwargs)
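
If something along those lines landed, the opt-out would look like this at the call site (hypothetical usage, and only for layers whose forward accepts extra kwargs):

hidden_states = layer(hidden_states, use_kernel=False)  # force the original forward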

@qubvel (Contributor) commented Apr 18, 2025:

Then I can manage it on the module call.

@gante (Contributor, Author):

Yeah, good idea, that could work!

@gante (Contributor, Author) commented Apr 18, 2025:

@qubvel it hits a similar barrier: many layers don't receive **kwargs in forward, so arbitrary kwargs passed as in model(..., use_kernel=False) will never reach them :(

example:

class Phi3RMSNorm(nn.Module):
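
For reference, a simplified copy of that layer (trimmed from the actual modeling file): forward takes only hidden_states, so an extra use_kernel kwarg has nowhere to go.

import torch
from torch import nn

class Phi3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):  # no **kwargs to pop a flag from
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)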

@Cyrilvallez (Member) commented Apr 18, 2025:

Custom backends (https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) could be a way to raise exceptions at compile time, but probably a huge struggle (as we allow users to actually choose the backend). Just posting here in case, but I don't really believe in it.
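
For what it's worth, a minimal sketch of that idea (an assumption, not something attempted in this PR; the kernels detection and the DISABLE_KERNEL_MAPPING check are hypothetical wiring). A custom backend runs once per traced graph, so raising inside it surfaces at compile time; the catch, as noted, is that it only helps if users don't pick their own backend.

import os

import torch
from torch import nn

kernels_installed = True  # assumption: in practice, detect whether the kernels package is importable

def strict_backend(gm: torch.fx.GraphModule, example_inputs):
    if kernels_installed and os.environ.get("DISABLE_KERNEL_MAPPING") != "1":
        raise RuntimeError("Custom kernels are not yet compatible with torch.compile.")
    return gm.forward  # otherwise run the traced graph as-is

model = torch.compile(nn.Linear(4, 4), backend=strict_backend)
model(torch.randn(2, 4))  # raises at compile time unless DISABLE_KERNEL_MAPPING=1 is set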

@gante (Contributor, Author):

@qubvel I see your point: we leave the option to pass the flag; some layers won't support it by default, but at least we have some degree of control 👍

@gante (Contributor, Author):

@qubvel added ✅ I went with the config route, since most layers have access to the config and it doesn't require forwarding any argument, but let me know if you'd prefer having control through an argument.

(sorry for being obtuse in the comments above: I was too focused on having a solution that worked in all cases, and completely missed the point that partial support would still be useful 😅)
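
Roughly what the config-gated fallback looks like (a sketch with illustrative names; the merged diff is the source of truth). disable_custom_kernels is the config attribute that some models, e.g. RT-DETR, already expose.

import torch

def is_torchdynamo_compiling():
    # simplified stand-in for the transformers utility
    return torch.compiler.is_compiling()

def make_config_gated_forward(original_forward, kernel_forward):
    # the flag is read from the module's config, so nothing has to be
    # threaded through the forward arguments
    def forward(self, *args, **kwargs):
        config = getattr(self, "config", None)
        if is_torchdynamo_compiling() or getattr(config, "disable_custom_kernels", False):
            return original_forward(self, *args, **kwargs)
        return kernel_forward(self, *args, **kwargs)

    return forward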

@qubvel (Contributor) commented Apr 18, 2025:

Okay, thanks! I just realized neither of the options solves my issue entirely, but at least it will respect the configuration parameter, so the current implementation is going to work. However, for the refactored version with CoreML export support, I want to know in advance which path is going to be executed before passing inputs into the module. I have no idea how to get this; previously, it was resolved the following way:

# kernel path
if kernel_loaded and not is_compiling and not custom_kernels_disabled:
    hidden_state = kernel_forward(hidden_state)  # guaranteed to be kernel forward

# eager path, avoid 6D tensor
else:
    hidden_state = hidden_state.reshape(...)
    hidden_state = deform_attn_function(hidden_state)

I will solve it when I return to the RT-DETR refactoring, and maybe by then the kernels library will be updated; it's not an urgent issue.

@gante (Contributor, Author):

Good! We can revisit this later anyway, if we find a better way or run into new issues 🤗

@Cyrilvallez (Member):
Ha yes, my idea was to raise directly in the custom backend if kernels are installed and the env variable is active, and force the use of that backend everywhere (thus not changing the kernel decorator). Too bad if it does not work!

@gante (Contributor, Author) commented Apr 18, 2025:

@qubvel @Cyrilvallez let me know if you'd like further changes

(and sorry for being pushy -- CI is broken in several places until this is merged or we update a bunch of CI images)

@qubvel (Contributor) left a comment:

Looks good to me, thanks 👍

@Cyrilvallez (Member) left a comment:

LGTM thanks! I like this simple solution for now, IMO we should really aim at not touching the modeling files, otherwise it's gonna scale very badly in the future! 🤗 If truly needed for some models (rt-detr apparently) we can have a local exception, but it should really stay an exception!

@gante gante merged commit 1930e75 into huggingface:main Apr 21, 2025
20 checks passed
@gante gante deleted the kernels_compile branch April 21, 2025 12:23
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025