-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow overriding of transpose function #74
Conversation
Yes I think this definitely makes sense. Some libraries (e.g. hptt), implement transpose directly for numpy arrays, which you would want to take precedence. |
I'll merge this shall I, unless you have comments @dgasmith? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
As per @jcmgray's comment. It would superb to get TBLIS and/or HPTT on conda-forge to try this out. Does HPTT compile to a single vectorized instruction or can it compile to multiple simultaneously like ICC multi arch or a more home grown solution as found in TBLIS?
@jcmgray Feel free to merge PR's if you wish with the usual caveat of something major should be discussed first. |
I'm actually not sure at all, I did add a numpy like transpose wrapper to
Great, noted. |
@dgasmith would it be possible to cut a release with this PR and the new |
@dgasmith Our Pyro release is planned for the week of Dec 3. Let me know if there is any additional documentation you'd like before the next opt_einsum release. Thanks! |
Description
This PR makes backends fully overridable by checking each backed for a transpose function before using the default
.transpose()
method.Before this PR,
_transpose()
would first look for a default implementation (x.transpose(axes)
), and only if this failed look for an overridden backendtranspose()
function. This approach has two flaws: (1) backends cannot override the transpose function if a different.transpose()
method already exists (e.g. in Pyro we want to track some metadata intranspose()
); (2) the try-except may not play well with jit compilers like the PyTorch jit.The only change in behavior is to the torch backend: before this PR
torch.transpose()
was used for transposes of 2 axes andtorch.permute()
was used for 3 or more axes; after this PRtorch.permute()
is always used.Status