Add groups to Conv1d #948
Conversation
Wdyt? This is for CPU only. The GPU code should be very similar to this so I want to get some feedback before I continue. Main changes:
@Rifur13 this looks cool! Do you intend to add the GPU kernel here? Also this will just be for 1D grouped convolutions, correct? It would also be great if you could run some benchmarks:
Yep I intend to add the GPU kernel as well. And yes, this PR will focus on 1D convolutions only. Benchmarks coming soon!
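For context, here is what grouped 1D convolution computes, as a minimal CPU reference sketch. It assumes mlx's channels-last (NLC) input layout, stride 1, and no padding, and is illustrative only, not the PR's implementation:

#include <vector>

// Naive reference for grouped 1D convolution (channels-last, stride 1,
// no padding). Shapes: in is [N, L, C_in], wt is [C_out, K, C_in / groups],
// out is [N, L - K + 1, C_out].
std::vector<float> grouped_conv1d_ref(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    int N, int L, int C_in, int C_out, int K, int groups) {
  int C_per_group = C_in / groups;   // input channels seen by each group
  int O_per_group = C_out / groups;  // output channels produced per group
  int L_out = L - K + 1;
  std::vector<float> out(N * L_out * C_out, 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int l = 0; l < L_out; ++l) {
      for (int o = 0; o < C_out; ++o) {
        int g = o / O_per_group;  // group this output channel belongs to
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          for (int c = 0; c < C_per_group; ++c) {
            // Each group reads only its own slice of the input channels.
            int ic = g * C_per_group + c;
            acc += in[(n * L + l + k) * C_in + ic] *
                   wt[(o * K + k) * C_per_group + c];
          }
        }
        out[(n * L_out + l) * C_out + o] = acc;
      }
    }
  }
  return out;
}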
Performance doesn’t look great; it scales worse with more groups.
What we really need is a specialized kernel. It would be good to have some working version in the meantime to unblock people (like me).
I’ll take another look actually. If I ignore the split-k specialization this seems very doable.
Just curious, what is the last column measuring? It's a difference from what to what exactly? CPU -> GPU?
No, it's actually mlx vs pytorch. They should scale similarly, so I use these numbers to measure performance. Also, small update: I'm trying to parallelize the computation across groups.
Actually, I would not do that. That is going to introduce a lot of overhead and subvert how we do job submission for the GPU. The best approach is to have a single kernel to do all the groups and handle that extra dimension in the thread grid or something like that. But I realize that might be a lot more work. A less good option that you could try is to use a concurrent command encoder. If you rebase on main, you will get some functionality to make that much easier.
Here is a very simple example of how we do that in concatenate now: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/primitives.cpp#L556-L564
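To make the thread-grid suggestion concrete, here is a CPU emulation of a single 3D dispatch where the third grid dimension indexes the group (batch size 1 for brevity; the names and layouts are illustrative assumptions, not mlx's Metal kernel):

// One "launch" (the triple loop standing in for a 3D grid) covers every
// group, instead of submitting a separate kernel per group.
void grouped_conv1d_one_dispatch(
    const float* in,  // [L, C_in] with C_in = groups * C_per_group
    const float* wt,  // [C_out, K, C_per_group] with C_out = groups * O_per_group
    float* out,       // [L_out, C_out]
    int L, int K, int groups, int C_per_group, int O_per_group) {
  int L_out = L - K + 1;
  int C_in = groups * C_per_group;
  int C_out = groups * O_per_group;
  // grid = (O_per_group, L_out, groups): the extra z dimension replaces
  // a per-group kernel launch.
  for (int z = 0; z < groups; ++z)
    for (int y = 0; y < L_out; ++y)
      for (int x = 0; x < O_per_group; ++x) {
        int o = z * O_per_group + x; // global output channel
        float acc = 0.0f;
        for (int k = 0; k < K; ++k)
          for (int c = 0; c < C_per_group; ++c)
            acc += in[(y + k) * C_in + z * C_per_group + c] *
                   wt[(o * K + k) * C_per_group + c];
        out[y * C_out + o] = acc;
      }
}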
Thanks for guiding me in the right direction! Numbers look very good now and it’s ready for review.
Very nice result!! Will review soon.
mlx/backend/metal/conv.cpp (outdated diff):
// Transpose unfolded inputs
array in_view(
    {in_unfolded.shape(0), conv_params.C, kernel_size},
    in_unfolded.dtype(),
    nullptr,
    {});
in_view.copy_shared_buffer(
    in_unfolded,
    {in_unfolded.strides(0), 1, static_cast<size_t>(conv_params.C)},
    in_unfolded.flags(),
    in_unfolded.data_size());

// Materialize
auto in_transpose = array(in_view.shape(), in_view.dtype(), nullptr, {});
copy_gpu(in_view, in_transpose, CopyType::General, s);
I can also create a new unfold kernel and do this transpose directly in there. It will avoid an extra copy. Wdyt?
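A sketch of what that fused unfold could look like, as a CPU analogue (hypothetical; not the kernel that ended up in the PR): each window is written directly in the transposed [N * L_out, C, K] layout, so the separate transpose/copy pass above disappears.

#include <cstddef>
#include <vector>

// Hypothetical CPU analogue of an unfold that emits the transposed layout
// directly. Input is [N, L, C] (channels-last); output is [N * L_out, C, K]
// instead of the usual [N * L_out, K, C].
std::vector<float> unfold_transposed(
    const std::vector<float>& in, int N, int L, int C, int K) {
  int L_out = L - K + 1;
  std::vector<float> out(static_cast<std::size_t>(N) * L_out * C * K);
  for (int n = 0; n < N; ++n)
    for (int l = 0; l < L_out; ++l)
      for (int k = 0; k < K; ++k)
        for (int c = 0; c < C; ++c) {
          // Destination is laid out as [window, channel, tap].
          out[((static_cast<std::size_t>(n) * L_out + l) * C + c) * K + k] =
              in[(static_cast<std::size_t>(n) * L + l + k) * C + c];
        }
  return out;
}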
Sounds like a good idea to me!
CC @jagrit06
@Rifur13 did you do any benchmarking for the CPU version? It's not a super high priority to make it fast, but we also don't want to make it worse than it was
There’s an extra copy, so in theory it should be worse, but I didn’t see a noticeable difference in my tests. The code for convolutions when groups = 1 is unchanged now, so its performance is identical to before. I refactored the code to remove this copy and I think it also looks a lot cleaner. It’s easier to understand the code with groups vs. without groups.
Any notes or concerns?
Not really on my side. I think we can merge this; the results are very nice and the code looks good!! @jagrit06 or @angeloskath do either of you care to take a quick look?
It looks great!
Just a couple of things:
Could you add an error thrown in the vjp of convolutions for now if groups != 1?
Also, is there any reason the gemm we go to for grouped convs needs to be a separate kernel? It looks like the same gemm kernel, so we don’t need to have it as a separate kernel and add to the size of the metallib.
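The requested guard amounts to something like this minimal sketch (the function name and message are assumptions, not mlx's actual code):

#include <stdexcept>

// Minimal sketch of the requested behavior: fail loudly in the
// convolution vjp until grouped gradients are implemented.
void check_conv_vjp_groups(int groups) {
  if (groups != 1) {
    throw std::invalid_argument(
        "[conv] Gradient (vjp) not yet implemented for groups > 1.");
  }
}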
@jagrit06 Good catch, I’ll add that error for the vjp. The existing gemm kernel uses the 3rd grid dim as the batch size. I think it’s possible if we set the batch strides appropriately.
Exactly as you suggest, we can set the batch strides to let the existing gemm kernel handle the groups. Thanks!
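To illustrate the trick, here is a self-contained sketch of a strided batched gemm (not mlx's gemm kernel; all names here are illustrative): treating the group index as the batch dimension, with each operand stepped by its own batch stride, lets one batched call cover all groups.

#include <cstddef>

// Sketch of a strided batched gemm: C[b] = A[b] * B[b], row-major, where
// each operand advances by its own batch stride and has its own leading
// dimension.
void gemm_strided_batched(
    const float* A, const float* B, float* C,
    int M, int N, int K, int batch,
    int lda, int ldb, int ldc,
    std::ptrdiff_t strideA, std::ptrdiff_t strideB, std::ptrdiff_t strideC) {
  for (int b = 0; b < batch; ++b) {
    const float* a = A + b * strideA;  // e.g. one group's unfolded inputs
    const float* w = B + b * strideB;  // e.g. one group's weight slice
    float* c = C + b * strideC;        // one group's output channels
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k)
          acc += a[i * lda + k] * w[k * ldb + j];
        c[i * ldc + j] = acc;
      }
  }
}

For grouped conv this would be called with batch = groups, M = the number of output windows, K = C_per_group * kernel_size, and N = C_out / groups; since each group writes an interleaved slice of the output channels, ldc = C_out while strideC = C_out / groups.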
Done! Thanks for all the suggestions. Ready for a final review.
Thank you so much for the good work!
We should be good to merge once the tests pass
@Rifur13 the conv 1d test failed. Do you mind checking it?
Tests should pass now. Tricky one…
It's failing Metal validation. You should be able to reproduce locally with:
Fixed! Probably a good idea to add these test options to the docs somewhere.
@Rifur13 sorry for the delay in merging; this caused a conflict. If you can fix it we can merge ASAP. I also don't mind fixing the conflict tomorrow sometime.
Rebased. Should be fixed now.
Thanks, this is awesome!
Proposed changes
Adding groups to 1D convolutions. Resolves #237.
Checklist
Put an x in the boxes that apply.
[x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes