Add groups to Conv1d #948

Merged: 10 commits into ml-explore:main on Apr 27, 2024

Conversation

@Rifur13 (Contributor) commented Apr 1, 2024

Proposed changes

Adding groups to 1D convolutions. Resolves #237.
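For illustration, a minimal usage sketch of the new parameter (a sketch only; the shapes follow MLX's channels-last conv1d convention, and the exact keyword defaults are an assumption, not taken from this PR):

```python
import mlx.core as mx

N, L, C_in, C_out, K, groups = 4, 32, 32, 32, 5, 4

x = mx.random.normal((N, L, C_in))                # input: (batch, length, channels)
w = mx.random.normal((C_out, K, C_in // groups))  # each filter only sees C_in / groups channels

# With groups > 1 the input channels are split into `groups` chunks and each
# chunk is convolved with its own block of C_out / groups filters.
y = mx.conv1d(x, w, stride=1, padding=2, groups=groups)
print(y.shape)  # (4, 32, 32)
```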

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Rifur13 (Contributor, Author) commented Apr 1, 2024

Wdyt? This is for CPU only. The GPU code should be very similar to this so I want to get some feedback before I continue.

Main changes:

  • The input and kernel weights need to be transposed to cleanly split up the input into groups for the matmuls.
  • The result of each grouped convolution won’t be contiguous in the output, so it needs to be inserted with a slice (see the sketch below).
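A rough numpy sketch of that approach (illustrative only, not the actual MLX implementation; stride 1 and channels-last layout assumed):

```python
import numpy as np

def grouped_conv1d_ref(x, w, groups, padding=0):
    """x: (N, L, C_in), w: (C_out, K, C_in // groups); stride 1 for simplicity."""
    N, L, C_in = x.shape
    C_out, K, _ = w.shape
    x = np.pad(x, ((0, 0), (padding, padding), (0, 0)))
    L_out = L + 2 * padding - K + 1

    # Unfold: every output position gets its (K, C_in) window.
    cols = np.stack([x[:, i:i + K, :] for i in range(L_out)], axis=1)  # (N, L_out, K, C_in)

    out = np.empty((N, L_out, C_out))
    cg, og = C_in // groups, C_out // groups
    for g in range(groups):
        # Each group's matmul only sees its own slice of channels and filters...
        lhs = cols[..., g * cg:(g + 1) * cg].reshape(N, L_out, K * cg)
        rhs = w[g * og:(g + 1) * og].reshape(og, K * cg).T
        # ...and its result lands in a (non-contiguous) slice of the output.
        out[..., g * og:(g + 1) * og] = lhs @ rhs
    return out
```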

@awni (Member) commented Apr 3, 2024

@Rifur13 this looks cool! Do you intend to add the GPU kernel here? Also this will just be for 1D grouped convolutions, correct?

Also, it would be great if you could run some benchmarks:

  • Regular conv before/after this change (make sure there is no regression)
  • Grouped conv (ideally much faster when using lots of groups compared to the same-shape conv with a single group); see the sketch below
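For example, a rough benchmark sketch along those lines (the timing helper, shapes, and the MLX-vs-PyTorch comparison are illustrative, not the script actually used in this thread):

```python
import time
import mlx.core as mx
import torch

def bench(fn, iters=100):
    fn()  # warm-up
    tic = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - tic) / iters

N, L, C, O, K = 4, 32, 256, 512, 5
for groups in (1, 2, 128, 256):
    x_mx = mx.random.normal((N, L, C))                 # MLX is channels-last
    w_mx = mx.random.normal((O, K, C // groups))
    t_mx = bench(lambda: mx.eval(mx.conv1d(x_mx, w_mx, padding=2, groups=groups)))

    x_pt = torch.randn(N, C, L)                        # PyTorch is channels-first
    w_pt = torch.randn(O, C // groups, K)
    t_pt = bench(lambda: torch.nn.functional.conv1d(x_pt, w_pt, padding=2, groups=groups))

    print(f"groups={groups}: mlx {t_mx * 1e3:.3f} ms, torch {t_pt * 1e3:.3f} ms")
```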

@Rifur13 (Contributor, Author) commented Apr 3, 2024

Yep I intend to add the GPU kernel as well. And yes, this PR will focus on 1D convolutions only.

Benchmarks coming soon!

@Rifur13 (Contributor, Author) commented Apr 9, 2024

Performance doesn’t look great; it scales worse as the number of groups increases.



| (N, iH, C)   | (O, wH, C)    | dtype   | stride | pads | groups | diff%    |
|--------------|---------------|---------|--------|------|--------|----------|
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 1      | +179.77% |
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 2      | +59.62%  |
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 4      | +33.96%  |
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 8      | +0.71%   |
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 8      | +15.49%  |
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 16     | -32.81%  |
| (4, 32, 32)  | (32, 5, 32)   | float32 | 1      | 2    | 32     | -62.36%  |
| (4, 32, 256) | (512, 5, 256) | float32 | 1      | 2    | 2      | +41.59%  |
| (4, 32, 256) | (512, 5, 256) | float32 | 1      | 2    | 128    | -88.60%  |
| (4, 32, 256) | (512, 5, 256) | float32 | 1      | 2    | 256    | -93.96%  |

What we really need is a specialized steel_matmul that splits up the inputs into groups and dispatches the kernels in parallel.
It might take me a while to understand all the gemm kernel code. I’m not sure how much time I’ll have, so if someone really needs it they can take up this work.

It would be good to have some working version in the meantime to unblock people (like me).

@Rifur13 (Contributor, Author) commented Apr 9, 2024

I’ll take another look actually. If I ignore the split k specialization this seems very doable.

@awni (Member) commented Apr 12, 2024

Just curious, what is the last column measuring? It's a difference from what to what exactly? CPU -> GPU?

@Rifur13 (Contributor, Author) commented Apr 12, 2024

No, it's actually MLX vs PyTorch. They should scale similarly, so I use these numbers to measure performance.

Also, a small update: I'm trying to parallelize the groups for-loop by sending each kernel to a different command buffer. That means creating `groups` streams, `groups` command queues, etc. Working through some errors right now, but let me know if that makes sense.

@awni (Member) commented Apr 12, 2024

> Also, a small update: I'm trying to parallelize the groups for-loop by sending each kernel to a different command buffer. That means creating `groups` streams, `groups` command queues, etc. Working through some errors right now, but let me know if that makes sense.

Actually, I would not do that. That is going to introduce a lot of overhead and subvert how we do job submission for the GPU.

The best approach is to have a single kernel to do all the groups and handle that extra dimension in the thread grid or something like that. But I realize that might be a lot more work.

A less good option that you could try is to use a concurrent command encoder. If you rebase on main, you will get some functionality to make that much easier.

@awni (Member) commented Apr 12, 2024

Here is a very simple example of how we do that in concatenate now: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/primitives.cpp#L556-L564

@Rifur13 (Contributor, Author) commented Apr 14, 2024

Thanks for guiding me in the right direction! The numbers look very good now and it’s ready for review.

| N | iH | C   | O   | wH | C   | dtype   | stride | pads | groups | diff%    |
|---|----|-----|-----|----|-----|---------|--------|------|--------|----------|
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 1      | +189.00% |
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 2      | +176.95% |
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 4      | +185.48% |
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 8      | +183.16% |
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 8      | +181.10% |
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 16     | +145.79% |
| 4 | 32 | 32  | 32  | 5  | 32  | float32 | 1      | 2    | 32     | +102.98% |
| 4 | 32 | 256 | 512 | 5  | 256 | float32 | 1      | 2    | 2      | +110.27% |
| 4 | 32 | 256 | 512 | 5  | 256 | float32 | 1      | 2    | 128    | +50.08%  |
| 4 | 32 | 256 | 512 | 5  | 256 | float32 | 1      | 2    | 256    | +28.68%  |

Rifur13 marked this pull request as ready for review on April 14, 2024 at 21:03.
@awni (Member) commented Apr 15, 2024

Very nice result!! Will review soon.

Review thread on mlx/ops.cpp (outdated, resolved)
Comment on lines 144 to 158
```cpp
// Transpose unfolded inputs
array in_view(
    {in_unfolded.shape(0), conv_params.C, kernel_size},
    in_unfolded.dtype(),
    nullptr,
    {});
in_view.copy_shared_buffer(
    in_unfolded,
    {in_unfolded.strides(0), 1, static_cast<size_t>(conv_params.C)},
    in_unfolded.flags(),
    in_unfolded.data_size());

// Materialize
auto in_transpose = array(in_view.shape(), in_view.dtype(), nullptr, {});
copy_gpu(in_view, in_transpose, CopyType::General, s);
```
@Rifur13 (Contributor, Author) commented:

I can also create a new unfold kernel and do this transpose directly in there. It will avoid an extra copy. wdyt?

A maintainer (Member) replied:

Sounds like a good idea to me!

CC @jagrit06

@awni (Member) commented Apr 18, 2024

@Rifur13 did you do any benchmarking for the CPU version? It's not a super high priority to make it fast, but we also don't want to make it worse than it was.

@Rifur13 (Contributor, Author) commented Apr 18, 2024

There’s an extra copy, so in theory it should be worse, but I didn’t see a noticeable difference in my tests. The code for convolutions when groups = 1 is unchanged now, so the performance is identical to before.

I refactored the code to remove this copy and I think it also looks a lot cleaner. It’s easier to understand the code for groups vs without groups.

@Rifur13 (Contributor, Author) commented Apr 23, 2024

Any notes or concerns?

@awni (Member) commented Apr 24, 2024

Not really on my side. I think we can merge this, results are very nice and code looks good!! @jagrit06 or @angeloskath do either of you care to take a quick look?

@jagrit06 (Member) left a review comment:

It looks great, just a couple of things:

Could you add an error thrown in the vjp of convolutions for now if groups != 1?

Also, is there any reason the gemm we go to for grouped convs needs to be a separate kernel? It looks like the same gemm kernel, so we don’t need to have it as a separate kernel and add to the size of the metallib.

@Rifur13 (Contributor, Author) commented Apr 24, 2024

@jagrit06 Good catch, I’ll add a comment for the jvp.

The existing gemm kernel uses the 3rd grid dim as the batch size.
Are you suggesting we repurpose batches as groups? Readability would take a hit imo.

I think it’s possible if we set:

```cpp
params->batch_ndim = 1
params->batch_stride_a = K
params->batch_stride_b = N * K
params->batch_stride_d = N
```
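For what it's worth, a small numpy sketch of why those strides make the batched gemm behave like a grouped gemm (the sizes are made up for illustration, and the loop index `z` plays the role of `tid.z`):

```python
import numpy as np

M, K, N, groups = 6, 4, 3, 2           # per-group gemm sizes, chosen arbitrarily

A = np.random.randn(M, groups * K)     # unfolded input, one K-wide block of columns per group
B = np.random.randn(groups, K, N)      # per-group weight blocks, packed contiguously
D = np.zeros((M, groups * N))          # output, one N-wide block of columns per group

# With batch_ndim = 1 and the strides above, grid index z offsets A by z*K,
# B by z*N*K, and D by z*N, then runs an ordinary M x N x K gemm.
for z in range(groups):
    D[:, z * N:(z + 1) * N] = A[:, z * K:(z + 1) * K] @ B[z]

# That is exactly a multiply by the block-diagonal weight matrix of a grouped conv.
W = np.zeros((groups * K, groups * N))
for z in range(groups):
    W[z * K:(z + 1) * K, z * N:(z + 1) * N] = B[z]
assert np.allclose(D, A @ W)
```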

@jagrit06 (Member) replied:

> @jagrit06 Good catch, I’ll add a comment for the jvp.
>
> The existing gemm kernel uses the 3rd grid dim as the batch size. Are you suggesting we repurpose batches as groups? Readability would take a hit imo.
>
> I think it’s possible if we set:
>
> params->batch_ndim = 1
> params->batch_stride_a = K
> params->batch_stride_b = N * K
> params->batch_stride_d = N

Exactly as you suggest, we can set the batch strides and let the tid.z handle that.
I don't think the readability hit is bad enough to justify the overhead of compiling and packing a whole new set of gemm kernels that are basically the same as the ones we already have.

Thanks!

@Rifur13 (Contributor, Author) commented Apr 24, 2024

Done! Thanks for all the suggestions.

Ready for a final review.

@jagrit06 (Member) left a review:

Thank you so much for the good work!
We should be good to merge once the tests pass.

@awni (Member) commented Apr 25, 2024

@Rifur13 the conv 1d test failed. Do you mind checking it?

@Rifur13 (Contributor, Author) commented Apr 25, 2024

Tests should pass now. Tricky one..

@awni (Member) commented Apr 25, 2024

It's failing Metal validation. You should be able to reproduce locally with:

METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python ..

@Rifur13 (Contributor, Author) commented Apr 25, 2024

Fixed! It's probably a good idea to add these test options to the docs somewhere.

@awni (Member) commented Apr 27, 2024

@Rifur13 sorry, the delay in merging this caused a conflict. If you can fix it, we can merge ASAP. I also don't mind fixing the conflict sometime tomorrow.

@Rifur13 (Contributor, Author) commented Apr 27, 2024

Rebased. Should be fixed now

@awni (Member) left a review:

Thanks, this is awesome!

@awni merged commit c4a471c into ml-explore:main on Apr 27, 2024 (5 checks passed).
Linked issue: Feature Request: groups parameter in Conv1d (#237)