diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl index 1b1c3f271..a0edb9d80 100644 --- a/src/dim_helpers/ConvDims.jl +++ b/src/dim_helpers/ConvDims.jl @@ -77,6 +77,25 @@ function im2col_dims(c::ConvDims) ) end +""" + ∇filter_im2col_dims(c::ConvDims) + +Like [`im2col_dims`](@ref), but saves some memory because multiple (Julia) threads are +not required for the filter gradient calculation. + +Note: in the future, this may return `Dims{2}` instead of `Dims{3}`. +""" +function ∇filter_im2col_dims(c::ConvDims) + return ( + # Output size + prod(output_size(c)), + # Size of single dotproduct within convolution + prod(kernel_size(c))*channels_in(c), + # No threading, this is just here for backwards compat + 1 + ) +end + # Protect your skin, kids. Also do common validation of stride, padding, etc... function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N} # Number of spatial dimensions in `x` and `w`. diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index 38f2ff50b..3d2702b8b 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -60,15 +60,16 @@ function conv_im2col!( end """ - ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0) + ∇conv_filter_im2col!(dw, x, dy, cdims; col=similar(dw, ∇filter_im2col_dims(cdims)), + alpha=1, beta=0) Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`. -See the documentation for `conv_im2col!()` for explanation of optional parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters.
""" function ∇conv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DenseConvDims; - col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), + col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(x), size(dw), size(dy), cdims) @@ -115,7 +116,7 @@ end ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`. -See the documentation for `conv_im2col!()` for explanation of other parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ function ∇conv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl index b1c8adc2f..9d133174b 100644 --- a/src/impl/depthwiseconv_im2col.jl +++ b/src/impl/depthwiseconv_im2col.jl @@ -5,8 +5,7 @@ depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) Perform a depthwise convolution using im2col and GEMM, store the result in `y`. - -See `conv_im2col!()` for an explanation of optional parameters. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ depthwiseconv_im2col! @@ -48,17 +47,18 @@ function depthwiseconv_im2col!( end """ - ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta) + ∇depthwiseconv_filter_im2col!(dw, x, dy, cdims; col=similar(dw, ∇filter_im2col_dims(cdims)), + alpha=1, beta=0) -Depthwise conv2d backward pass onto the weights using im2col and GEMM. -See the documentation for `conv_im2col!()` for explanation of optional parameters. +Depthwise conv backward pass onto the weights using im2col and GEMM. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ ∇depthwiseconv_filter_im2col!
function ∇depthwiseconv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DepthwiseConvDims; - col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), + col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where T check_dims(size(x), size(dw), size(dy), cdims) @@ -66,9 +66,13 @@ function ∇depthwiseconv_filter_im2col!( N = channel_multiplier(cdims) K = prod(output_size(cdims)) - @threads for batch_idx in 1:size(x)[end] + for batch_idx in 1:size(x, 5) + # Because we accumulate over batches in this loop, we must set `beta` equal + # to `1.0` after the first sample. + beta′ = batch_idx == 1 ? beta : T(1) + - # col_slice is a thread-local workspace - col_slice = view(col, :, :, threadid()) + # col_slice is the single im2col workspace, reused for every batch element + col_slice = view(col, :, :, 1) im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims) # We do a separate convolution for each channel in x, as we must @@ -78,22 +82,18 @@ function ∇depthwiseconv_filter_im2col!( col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1) dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1) dw_ptr = pointer(dw, (c_in - 1)*M*N + 1) - gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) + gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta′, dw_ptr) end end - - # Because we accumulate over batches in this loop, we must set `beta` equal - # to `1.0` from this point on. - beta = T(1) end return dw end """ - depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) + ∇depthwiseconv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) -Depwthwise conv2d backward pass onto the input using im2col and GEMM. -See the documentation for `conv_im2col!()` for explanation of optional parameters. +Depthwise conv2d backward pass onto the input using im2col and GEMM. +See [`conv_im2col!`](@ref) for explanation of optional parameters. """ ∇depthwiseconv_data_im2col!
diff --git a/test/conv.jl b/test/conv.jl index c3ce97ec9..e93100217 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -711,7 +711,6 @@ end end - # Currently hangs due to a FiniteDifferences issue - @test_skip gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′) + gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′) end @testset "conv_wrapper" begin @@ -759,25 +758,14 @@ end y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) - # if spatial_rank == 3 - # @test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w) - # else - gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w) - # end - gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y) - if spatial_rank < 3 - gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y) - end + gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w) + gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y) + gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y) dcdims = DepthwiseConvDims(x, w) gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w) - # FIXME fails - # y = depthwiseconv(x, w, dcdims) - # gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) - # if spatial_rank == 3 - # @test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) - # else - @test_skip gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) - # end + y = depthwiseconv(x, w, dcdims) + gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) + gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end