Skip to content

Commit

Permalink
Remove threading from all ∇*conv_filter and re-enable old tests
Browse files Browse the repository at this point in the history
This should hopefully fix persistent threading-related issues on CI,
while also slightly reducing memory usage when taking gradients.
  • Loading branch information
ToucheSir committed Nov 2, 2022
1 parent f5fd67e commit 1ce5720
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 36 deletions.
19 changes: 19 additions & 0 deletions src/dim_helpers/ConvDims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ function im2col_dims(c::ConvDims)
)
end

"""
    ∇filter_im2col_dims(c::ConvDims)

Return the dimensions of the `im2col` workspace used for the filter gradient.

Like [`im2col_dims`](@ref), but saves some memory because multiple (Julia) threads are
not required for the filter gradient calculation. The result is a 3-tuple holding the
product of the output size, the product of the kernel size times the number of input
channels, and a trailing `1`.

Note: in the future, this may return `Dims{2}` instead of `Dims{3}`.
"""
function ∇filter_im2col_dims(c::ConvDims)
    # Total number of output positions.
    n_out = prod(output_size(c))
    # Length of a single dot product within the convolution.
    patch_len = prod(kernel_size(c)) * channels_in(c)
    # No threading is used here; the trailing 1 exists only for backwards compatibility.
    return (n_out, patch_len, 1)
end

# Protect your skin, kids. Also do common validation of stride, padding, etc...
function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
# Number of spatial dimensions in `x` and `w`.
Expand Down
9 changes: 5 additions & 4 deletions src/impl/conv_im2col.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ function conv_im2col!(
end

"""
∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0)
∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
alpha=1, beta=0)
Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`.
See the documentation for `conv_im2col!()` for explanation of optional parameters.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
function ∇conv_filter_im2col!(
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
dy::AbstractArray{T,5}, cdims::DenseConvDims;
col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)),
col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
alpha::T=T(1), beta::T=T(0)) where {T}
check_dims(size(x), size(dw), size(dy), cdims)

Expand Down Expand Up @@ -115,7 +116,7 @@ end
∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`.
See the documentation for `conv_im2col!()` for explanation of other parameters.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
function ∇conv_data_im2col!(
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
Expand Down
30 changes: 15 additions & 15 deletions src/impl/depthwiseconv_im2col.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)
Perform a depthwise convolution using im2col and GEMM, store the result in `y`.
See `conv_im2col!()` for an explanation of optional parameters.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
depthwiseconv_im2col!

Expand Down Expand Up @@ -48,27 +47,32 @@ function depthwiseconv_im2col!(
end

"""
∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta)
∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
alpha=1, beta=0)
Depthwise conv2d backward pass onto the weights using im2col and GEMM.
See the documentation for `conv_im2col!()` for explanation of optional parameters.
Depthwise conv backward pass onto the weights using im2col and GEMM.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
∇depthwiseconv_filter_im2col!

function ∇depthwiseconv_filter_im2col!(
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;
col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)),
col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
alpha::T=T(1), beta::T=T(0)) where T
check_dims(size(x), size(dw), size(dy), cdims)

M = prod(kernel_size(cdims))
N = channel_multiplier(cdims)
K = prod(output_size(cdims))

@threads for batch_idx in 1:size(x)[end]
for batch_idx in 1:size(x, 5)
# Because we accumulate over batches in this loop, we must set `beta` equal
# to `1.0` after the first sample.
beta′ = batch_idx == 1 ? beta : T(1)

# col_slice is the im2col workspace; without threading only the first slice of `col` is used
col_slice = view(col, :, :, threadid())
col_slice = view(col, :, :, 1)
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)

# We do a separate convolution for each channel in x, as we must
Expand All @@ -78,22 +82,18 @@ function ∇depthwiseconv_filter_im2col!(
col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1)
dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
end
end

# Because we accumulate over batches in this loop, we must set `beta` equal
# to `1.0` from this point on.
beta = T(1)
end
return dw
end

"""
depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
∇depthwiseconv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
Depthwise conv2d backward pass onto the input using im2col and GEMM.
See the documentation for `conv_im2col!()` for explanation of optional parameters.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
∇depthwiseconv_data_im2col!

Expand Down
24 changes: 7 additions & 17 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ end
end

# Previously skipped: this used to hang, which was attributed to a FiniteDifferences issue
@test_skip gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′)
gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′)
end

@testset "conv_wrapper" begin
Expand Down Expand Up @@ -759,25 +759,15 @@ end

y = conv(x, w, cdims)
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
# if spatial_rank == 3
# @test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
# else
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
# end
gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
if spatial_rank < 3
gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)
end
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)

dcdims = DepthwiseConvDims(x, w)
gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

# FIXME fails
# y = depthwiseconv(x, w, dcdims)
# gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
# if spatial_rank == 3
# @test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
# else
@test_skip gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
# end
y = depthwiseconv(x, w, dcdims)
gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end

0 comments on commit 1ce5720

Please sign in to comment.