Restore fast convolution Ops, rewrites, and docs #548
Conversation
+1. Following the NumPy/SciPy API would be best. Part of the reason we usually do that is that it saves us from having to think about it, and saves users from having to learn something new.

SciPy's `signal` module has `convolve`/`convolve2d` and `correlate` (I saw we have a `corr` something?), and they are already batched (at least `convolve` is). JAX also followed their API: https://jax.readthedocs.io/en/latest/notebooks/convolutions.html

I would suggest implementing an Op like the abstract ones, perhaps much simpler, that corresponds to this API, and then specializing via rewrites to the fancy ones. If you think one of them is much less relevant, feel free to kick it out. I would focus first on mapping to JAX (and Numba, if they have anything); if some of these cases have a nice correspondence to the old C Ops, all the better, and if not, one more reason to kick them out.

Also, I see they had old logic for shape inference for the C code; we should update it to use the new static-shape machinery.
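For reference, a minimal sketch of the NumPy/SciPy convolution API the comment refers to. This uses plain `np.convolve` for illustration; `scipy.signal.convolve`/`correlate` (and `jax.numpy.convolve`) expose the same three `mode` options, which is the API surface an abstract Op would mirror:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
k = np.array([0.0, 1.0, 0.5])

# The three boundary modes shared by numpy, scipy.signal, and jax.numpy:
full = np.convolve(x, k, mode="full")    # length len(x) + len(k) - 1
same = np.convolve(x, k, mode="same")    # same length as x (centered)
valid = np.convolve(x, k, mode="valid")  # only fully-overlapping positions

# For real inputs, correlation is convolution with the kernel reversed,
# which is why scipy pairs signal.convolve with signal.correlate:
corr_full = np.convolve(x, k[::-1], mode="full")
```

Batching in JAX then falls out for free via `jax.vmap` over the leading axis, which is one argument for specializing to the JAX backend via rewrites rather than hand-writing batched kernels.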
Restore fast convolution functions and rewrites, as well as convolution docs
Description
When the `nnet` sub-module was deprecated, the old Theano convolution functions went into the trashcan, along with the associated docs. This was partially reverted, but the docs and the "efficient" versions of the Ops are still missing. There is interest in convolutions on the Discourse from time to time, so it's worth talking about these.
This PR is a draft because I just restored everything. To be frank, this part of the library seems somewhat bloated, and would probably be better served by an overhaul than by simply restoring this code. For example, there is commentary about CUDA kernels in the docstrings; I assume this dates from the days when Theano was trying to target GPUs with PyCUDA. We use JAX these days, and I wager these Ops don't compile to JAX at all (though I haven't tried), or to Numba for that matter.
It's marked as a draft for two reasons:
Comments requested.
Related Issue
Checklist
Type of change