
Disable BLAS threading within conv_im2col! etc. #395

Status: Closed · wants to merge 1 commit

Conversation

@mcabbott (Member) commented Feb 22, 2022

Looking at #234, the implementation of im2col! seems to have three nested levels of multi-threading: a @spawn, a @threads loop over the batch dimension, and then BLAS's own threads. That might not be optimal.

This PR finds roughly 30% speedups by keeping just the @threads loop. It could use more testing, though; in particular it would be good for someone to try it on a newer Intel machine with MKL.

Before:

julia> using Flux, BenchmarkTools, LinearAlgebra

julia> x = randn(Float32, 224, 224, 3, 32);  # more batch dim than https://github.com/FluxML/NNlib.jl/issues/234

julia> @btime copy($x);
  min 2.411 ms, mean 2.712 ms (2 allocations, 18.38 MiB)  # first line: Intel Xeon E5-2603, Julia 1.7 + OpenBLAS
  min 1.389 ms, mean 1.962 ms (2 allocations, 18.38 MiB)  # second line (throughout): M1 Mac, Julia 1.8 + Accelerate

julia> c = Conv((7,7), 3 => 64; stride=2, pad=3);

julia> cdims = Flux.DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups);

julia> Threads.nthreads()
4

julia> BLAS.set_num_threads(4)

julia> @btime Flux.conv($x, $c.weight, $cdims);
  min 142.908 ms, mean 208.501 ms (32 allocations, 126.14 MiB)
  min 64.856 ms, mean 107.744 ms (66 allocations, 126.14 MiB)

julia> BLAS.set_num_threads(1)

julia> @btime Flux.conv($x, $c.weight, $cdims);
  min 120.732 ms, mean 350.911 ms (32 allocations, 126.14 MiB)  # faster min, slower mean, just a bad run?
  min 52.465 ms, mean 93.704 ms (67 allocations, 126.14 MiB)    # faster

# With the example from https://github.com/FluxML/NNlib.jl/issues/234 exactly:

julia> BLAS.set_num_threads(4)

julia> @btime conv($dummy);
  min 10.225 ms, mean 14.212 ms (37 allocations, 40.39 MiB)  # xeon
  min 5.835 ms, mean 8.560 ms (82 allocations, 40.39 MiB)    # M1

julia> BLAS.set_num_threads(1)

julia> @btime conv($dummy);
  min 13.030 ms, mean 15.990 ms (37 allocations, 40.39 MiB)
  min 7.011 ms, mean 8.411 ms (82 allocations, 40.39 MiB)

julia> size(dummy)  # batch size 2, hence the @threads loop uses only 2 of the 4 threads
(224, 224, 3, 2)

After:

julia> @btime Flux.conv($x, $c.weight, $cdims);
  min 112.312 ms, mean 140.835 ms (32 allocations, 126.14 MiB)  # was min 142.908 ms, mean 208.501 ms
  min 52.618 ms, mean 79.594 ms (32 allocations, 126.14 MiB)    # was min 64.856 ms, mean 107.744 ms

julia> @btime conv($dummy);
  min 13.061 ms, mean 16.524 ms (37 allocations, 40.39 MiB)  # slower when batch size < nthreads; perhaps add a test for that case?
  min 7.397 ms, mean 9.666 ms (37 allocations, 40.39 MiB)

Comment on lines +199 to +204:

    if nthreads() > 1
        th = BLAS.get_num_threads()
        BLAS.set_num_threads(1)
        # conv_im2col! has a loop with @threads, and benchmarks show that this is usually
        # faster without BLAS multithreading, and without @spawn in the zip(x_cs, w_cs) loop.
    end
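(Only the disabling half appears in this hunk; presumably the saved th is restored via BLAS.set_num_threads(th) after the @threads loop.)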
Member
Do you know of any way to do this on a task or thread-local level? It's too bad that this requires mutating global state.

@mcabbott (Member, Author)
I don't, although I agree it's ugly.
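One could at least scope the mutation with a do-block helper, restoring the previous setting even on error. A rough sketch (the name with_blas_threads is hypothetical, not part of this PR, and it still mutates process-wide state, not per-task):

    using LinearAlgebra

    # Hypothetical helper: run f with n BLAS threads, restoring the
    # previous setting afterwards even if f throws.
    function with_blas_threads(f, n::Integer)
        old = BLAS.get_num_threads()
        BLAS.set_num_threads(n)
        try
            return f()
        finally
            BLAS.set_num_threads(old)
        end
    end

    # Usage: single-threaded BLAS only for the duration of the closure.
    with_blas_threads(1) do
        # ... the @threads loop over the batch dimension goes here ...
    end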

@darsnack (Member)

Shouldn't this also test a convolution with groups > 1? That's what the @spawn was for.

@mcabbott (Member, Author)

Maybe; got an example?

In this case the @spawn is pure loss, of course. Ideally I think we want exactly one multi-threading loop, but where to put it (without too much complexity) seems messy. Making the batch loop use @spawn might help them compose better.

@mcabbott changed the title from "Disable BLAS threading within im2col! etc." to "Disable BLAS threading within conv_im2col! etc." on Feb 22, 2022
@CarloLucibello (Member)

Shall we go ahead with this? It could also test the grouped conv in FluxML/Flux.jl#1921 (comment); e.g. something like the snippet below.

I can benchmark on an AMD Threadripper platform.
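A minimal grouped case to time might look like this (sizes are illustrative, not taken from that thread):

    julia> xg = randn(Float32, 56, 56, 64, 8);  # 64 channels, batch of 8

    julia> cg = Conv((3,3), 64 => 64; groups=4, pad=1);  # 4 groups of 16 channels each

    julia> @btime $cg($xg);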

@mcabbott (Member, Author) commented Apr 2, 2022

Benchmarks would be great. I stopped here because I thought we ought to check a whole range -- what if there is one group, or fewer channels than threads, etc.? Timing just one case could lead you anywhere.

Ideally we'd have multi-threading only on one outermost loop (with some heuristic to decide when the problem is big enough). If one loop can't cover all cases, then IIRC @spawn within @spawn composes better than @threads, as sketched below. But sorting all of this out started to seem like a big project.
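A rough sketch of that shape, for a batch array x and preallocated output y (process_slice! is a placeholder, not NNlib code):

    using Base.Threads: @spawn

    # One task per batch slice; tasks spawned inside process_slice!
    # (e.g. for a grouped path) then compose with the scheduler instead
    # of oversubscribing the way nested @threads loops can.
    @sync for b in axes(x, 4)
        @spawn process_slice!(y, x, b)  # placeholder per-slice kernel
    end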

@ToucheSir (Member)

@CarloLucibello thoughts on expanding https://github.com/FluxML/Flux.jl/blob/master/perf/conv.jl to use as a benchmarking suite?
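One possible shape, sweeping batch size and group count with BenchmarkTools (the parameter ranges here are illustrative):

    using BenchmarkTools, Flux

    suite = BenchmarkGroup()
    for batch in (1, 2, 16, 32), groups in (1, 4)
        c = Conv((3, 3), 8 => 8; groups, pad=1)
        x = randn(Float32, 32, 32, 8, batch)
        suite["batch=$batch, groups=$groups"] = @benchmarkable $c($x)
    end
    results = run(suite; verbose=true)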
