-
Notifications
You must be signed in to change notification settings - Fork 816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the general convolution operation to extensions #954
Conversation
tests coming soon. |
Hey @DarrenZhang01, Is this ready to review? |
Hi Akshay @akshaym ! Actually not yet, there is a bug I need to deal with which I will mention in today's coming meeting. |
Hey @DarrenZhang01, let me know if I can take a look at this. If there are bugs lets just document them for now and add the code anyway? WDYT? |
Thanks very much for the support, Akshay! It is ready for being reviewed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You had mentioned a bug - is that still around? Or is that fixed?
So the original bug is that the transpose convolution was not supported due to the lack of output shape as I mentioned in our last Friday's meeting. Now I added an output shape evaluation helper. Besides, I used the test cases from JAX general convolution. Those cases are really broad and TF general conv currently do not cover them all, so I skipped some test cases and added a TODO to expand the test cases later on. The shape evaluation function does not seem so correct, so I remove it and use the output shape of JAX convolution directly as the input shape for TF convolution in the test cases. |
Thanks for the review, Akshay! |
Now it is ready for review again. @akshaym |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes @DarrenZhang01!
Just some small remaining nits!
Thanks a lot, Akshay! I revised the places that you mentioned. |
TF XLA version:
https://www.tensorflow.org/xla/operation_semantics?hl=en#conv_convolution;
JAX version: https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html