[kernels] use original forward at compile time #37604
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
@qubvel I'm not sure how to implement the manual escape path (as discussed on Slack and here) 🤔 From the decorator's perspective, we only have access to […]. Any suggestions or ideas?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
def forward_with_compile_path(*forward_args, **forward_kwargs):
    if is_torchdynamo_compiling():
        return original_forward(*forward_args, **forward_kwargs)
    else:
        return kernel_forward(*forward_args, **forward_kwargs)
```
Maybe not super clean, but we can have an extra kwarg, something like this:
Suggested change:

```diff
 def forward_with_compile_path(*forward_args, **forward_kwargs):
-    if is_torchdynamo_compiling():
+    use_kernel = forward_kwargs.pop("use_kernel", True)
+    if is_torchdynamo_compiling() or not use_kernel:
         return original_forward(*forward_args, **forward_kwargs)
     else:
         return kernel_forward(*forward_args, **forward_kwargs)
```
Then I can manage it on the module call
Yeah, good idea, that could work!
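For illustration only, here is a minimal self-contained sketch of that idea, with a toy layer and `torch.compiler.is_compiling()` standing in for `is_torchdynamo_compiling()` (the names and wiring are hypothetical, not the PR's implementation):

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Toy stand-in for a layer whose forward can be swapped for a hub kernel."""

    def original_forward(self, hidden_states):
        return hidden_states * 2  # eager/reference implementation

    def kernel_forward(self, hidden_states):
        return hidden_states * 2  # placeholder for the custom kernel

    def forward(self, *args, **kwargs):
        # Pop the flag so it never reaches the wrapped forwards.
        use_kernel = kwargs.pop("use_kernel", True)
        if torch.compiler.is_compiling() or not use_kernel:
            return self.original_forward(*args, **kwargs)
        return self.kernel_forward(*args, **kwargs)


layer = ToyLayer()
x = torch.ones(2, 3)
layer(x)                    # kernel path (default)
layer(x, use_kernel=False)  # opt out at the module call
```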
@qubvel it hits a similar barrier: many layers don't receive `**kwargs` in their `forward`, so they will never be reached by arbitrary kwargs passed as e.g. `model(..., use_kernel=False)` :(
example:
```python
class Phi3RMSNorm(nn.Module):
```
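A minimal sketch of that barrier, using a simplified RMSNorm-style layer (illustrative code, not the actual Phi3 modeling file):

```python
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Simplified RMSNorm: forward has a fixed signature, no **kwargs."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = SimpleRMSNorm(8)

    def forward(self, hidden_states):
        # No **kwargs plumbing: a `use_kernel` flag passed to the model
        # has no way to reach self.norm's forward call.
        return self.norm(hidden_states)


model = TinyModel()
x = torch.randn(2, 8)
model(x)                      # works
# model(x, use_kernel=False)  # TypeError: unexpected keyword argument 'use_kernel'
```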
Custom backends (https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) could be a way to raise exceptions at compile time, but it's probably a huge struggle (as we allow users to actually choose the backend). Just posting here in case, but I don't really believe in it.
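For context, a minimal sketch of what a custom backend looks like; the kernel check below is a hypothetical placeholder, not an existing hook:

```python
import torch


def kernels_guard_backend(gm: torch.fx.GraphModule, example_inputs):
    # A custom backend receives the captured FX graph. Raising here would
    # surface the incompatibility at compile time instead of silently
    # taking an unintended path.
    if getattr(gm, "_uses_custom_kernels", False):  # hypothetical marker
        raise RuntimeError("custom kernels are not compatible with torch.compile")
    return gm.forward  # otherwise just run the captured graph eagerly


model = torch.nn.Linear(4, 4)
compiled = torch.compile(model, backend=kernels_guard_backend)
compiled(torch.randn(2, 4))
```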
@qubvel I see your point: we leave the option to pass the flag, some layers won't do it by default, but at least we have some degree of control 👍
@qubvel added ✅ I went with the config route, since most layers have it and it doesn't require forwarding any argument, but let me know if you'd prefer having control through an argument.
(sorry for being obtuse in the comments above: I was too focused on having a solution that worked in all cases, and completely missed the point that partial support would still be useful 😅 )
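A minimal sketch of the config route, assuming a `disable_custom_kernels`-style flag (the attribute name and toy layer are illustrative, not the PR's exact code):

```python
import torch
import torch.nn as nn


class ConfigGatedLayer(nn.Module):
    """Toy layer whose kernel path is gated by a flag on the model config."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def original_forward(self, hidden_states):
        return hidden_states  # eager/reference implementation

    def kernel_forward(self, hidden_states):
        return hidden_states  # placeholder for the hub kernel

    def forward(self, hidden_states):
        # No extra forward argument needed: the flag lives on the config
        # that the layer already holds.
        disable = getattr(self.config, "disable_custom_kernels", False)
        if torch.compiler.is_compiling() or disable:
            return self.original_forward(hidden_states)
        return self.kernel_forward(hidden_states)


class DummyConfig:
    disable_custom_kernels = True


layer = ConfigGatedLayer(DummyConfig())
layer(torch.randn(2, 4))  # takes the original forward because of the config flag
```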
Okay, thanks! I just realized neither of the options solves my issue entirely, but at least it will respect the configuration parameter -> the current implementation is going to work. However, for the refactored version with CoreML export support, I want to know in advance which path is going to be executed before passing inputs into the module. I have no idea how to get this; previously, it was resolved the following way:
```python
# kernel path
if kernel_loaded and not is_compiling and not custom_kernels_disabled:
    hidden_state = kernel_forward(hidden_state)  # guaranteed to be kernel forward
# eager path, avoid 6D tensor
else:
    hidden_state = hidden_state.reshape(...)
    hidden_state = deform_attn_function(hidden_state)
```

I will solve it when I return to the RT_DETR refactoring, and maybe by that time the kernels library will be updated; not an urgent issue.
Good! We can revisit this later anyways, if we find a better way or new issues 🤗
Ha yes, my idea was to raise directly in the custom backend if kernels are installed and the env variable is active, and force the use of that backend everywhere (thus not changing the kernel decorator). Too bad if it does not work!
@qubvel @Cyrilvallez let me know if you'd like further changes (and sorry for being pushy -- CI is broken in several places until this is merged or we update a bunch of CI images)
qubvel left a comment:
Looks good to me, thanks 👍
Cyrilvallez left a comment:
LGTM thanks! I like this simple solution for now, IMO we should really aim at not touching the modeling files, otherwise it's gonna scale very badly in the future! 🤗 If truly needed for some models (rt-detr apparently) we can have a local exception, but it should really stay an exception!
What does this PR do?
`kernels` and `torch.compile` are not yet compatible with each other. Although we can skip custom kernels when the package is not installed, adding an error message is also not feasible -- we can't throw exceptions at compile time.

This PR hijacks the `kernels` decorator to add a compile-friendly path: until `kernels` supports `torch.compile`, let's use the original `forward`.
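Conceptually, the compile-friendly path looks like the sketch below (an illustration, not the exact code in this PR; `torch.compiler.is_compiling()` stands in for `is_torchdynamo_compiling()`):

```python
import torch


def with_compile_fallback(kernel_forward, original_forward):
    """Wrap a kernel-backed forward so compiled runs use the original implementation."""

    def forward_with_compile_path(*args, **kwargs):
        if torch.compiler.is_compiling():
            # torch.compile / dynamo tracing: stay on the original forward
            return original_forward(*args, **kwargs)
        # regular eager execution: use the custom kernel
        return kernel_forward(*args, **kwargs)

    return forward_with_compile_path
```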