From ec54732b5d94e628587eefe5ffcfbfc34488c22c Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 24 Oct 2022 22:21:43 -0700 Subject: [PATCH] =?UTF-8?q?Remove=20threading=20from=20all=20`=E2=88=87*co?= =?UTF-8?q?nv=5Ffilter`=20and=20re-enable=20old=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This should hopefully fix persistent threading-related issues on CI, While also slightly reducing memory usage when taking gradients. --- src/dim_helpers/ConvDims.jl | 19 +++++++++++++++++++ src/impl/conv_im2col.jl | 9 +++++---- src/impl/depthwiseconv_im2col.jl | 30 +++++++++++++++--------------- test/conv.jl | 24 +++++++----------------- 4 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl index 1b1c3f271..a0edb9d80 100644 --- a/src/dim_helpers/ConvDims.jl +++ b/src/dim_helpers/ConvDims.jl @@ -77,6 +77,25 @@ function im2col_dims(c::ConvDims) ) end +""" + ∇filter_im2col_dims(c::ConvDims) + +Like [`im2col_dims`](@ref), but saves some memory because multiple (Julia) threads are +not required for the filter gradient calculation. + +Note: in the future, this may return `Dims{2}` instead of `Dims{3}`. +""" +function ∇filter_im2col_dims(c::ConvDims) + return ( + # Output size + prod(output_size(c)), + # Size of single dotproduct within convolution + prod(kernel_size(c))*channels_in(c), + # No threading, this is just here for backwards compat + 1 + ) +end + # Protect your skin, kids. Also do common validation of stride, padding, etc... function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N} # Number of spatial dimensions in `x` and `w`. diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index 38f2ff50b..3d2702b8b 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -60,15 +60,16 @@ function conv_im2col!( end """ - ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0) + ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims)); + alpha=1, beta=0) Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`. -See the documentation for `conv_im2col!()` for explanation of optional parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ function ∇conv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DenseConvDims; - col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), + col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(x), size(dw), size(dy), cdims) @@ -115,7 +116,7 @@ end ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`. -See the documentation for `conv_im2col!()` for explanation of other parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ function ∇conv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl index b1c8adc2f..9d133174b 100644 --- a/src/impl/depthwiseconv_im2col.jl +++ b/src/impl/depthwiseconv_im2col.jl @@ -5,8 +5,7 @@ depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) Perform a depthwise convolution using im2col and GEMM, store the result in `y`. - -See `conv_im2col!()` for an explanation of optional parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ depthwiseconv_im2col! @@ -48,17 +47,18 @@ function depthwiseconv_im2col!( end """ - ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta) + ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims)); + alpha=1, beta=0) -Depthwise conv2d backward pass onto the weights using im2col and GEMM. -See the documentation for `conv_im2col!()` for explanation of optional parameters. +Depthwise conv backward pass onto the weights using im2col and GEMM. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ ∇depthwiseconv_filter_im2col! function ∇depthwiseconv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DepthwiseConvDims; - col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), + col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where T check_dims(size(x), size(dw), size(dy), cdims) @@ -66,9 +66,13 @@ function ∇depthwiseconv_filter_im2col!( N = channel_multiplier(cdims) K = prod(output_size(cdims)) - @threads for batch_idx in 1:size(x)[end] + for batch_idx in 1:size(x, 5) + # Because we accumulate over batches in this loop, we must set `beta` equal + # to `1.0` after the first sample. + beta′ = batch_idx == 1 ? beta : T(1) + # col_slice is a thread-local workspace - col_slice = view(col, :, :, threadid()) + col_slice = view(col, :, :, 1) im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims) # We do a separate convolution for each channel in x, as we must @@ -78,22 +82,18 @@ function ∇depthwiseconv_filter_im2col!( col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1) dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1) dw_ptr = pointer(dw, (c_in - 1)*M*N + 1) - gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) + gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta′, dw_ptr) end end - - # Because we accumulate over batches in this loop, we must set `beta` equal - # to `1.0` from this point on. - beta = T(1) end return dw end """ - depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) + ∇depthwiseconv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) Depwthwise conv2d backward pass onto the input using im2col and GEMM. -See the documentation for `conv_im2col!()` for explanation of optional parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ ∇depthwiseconv_data_im2col! diff --git a/test/conv.jl b/test/conv.jl index c3ce97ec9..969090d63 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -737,7 +737,7 @@ end end # https://github.com/FluxML/NNlib.jl/pull/171 -@testset "conv_direct! - Check Sizes" begin +@testset "conv_direct! - Check Sizes" begin x_size = (6, 7, 8, 5, 3) y_size = (5, 6, 7, 4, 3) w_size = (2, 2, 2, 5, 4) @@ -759,25 +759,15 @@ end y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) - # if spatial_rank == 3 - # @test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w) - # else - gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w) - # end - gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y) - if spatial_rank < 3 - gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y) - end + gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w) + gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y) + gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y) dcdims = DepthwiseConvDims(x, w) gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w) # FIXME fails - # y = depthwiseconv(x, w, dcdims) - # gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) - # if spatial_rank == 3 - # @test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) - # else - @test_skip gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) - # end + y = depthwiseconv(x, w, dcdims) + gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) + gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end