
conv_direct!(): The performance fix #142


Merged 3 commits on Nov 30, 2019
10 changes: 7 additions & 3 deletions src/dim_helpers.jl
@@ -62,11 +62,15 @@ spatial dimension at the end of the spatial dimensions. This does so for a Conv
)
end

-@inline function insert_singleton_spatial_dimension(x::AbstractArray)
-    return reshape(x, size(x)[1:end-2]..., 1, size(x)[end-1:end]...)
+# We specialize common cases
+@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,3}) where {T}
+    return reshape(x, size(x,1), 1, size(x,2), size(x,3))
end
+@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,4}) where {T}
+    return reshape(x, size(x,1), size(x,2), 1, size(x,3), size(x,4))
+end

-# Helper to do this multiple times
+# Helper to do this as many times as needed
@inline function insert_singleton_spatial_dimension(x, reps::Int)
    for r in 1:reps
        x = insert_singleton_spatial_dimension(x)
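For reference, here is what the two specialized methods do for one and two spatial dimensions (a minimal sketch; the sizes are made up for illustration):

```julia
# (spatial..., channels, batch) layout, as used throughout NNlib
x3 = rand(Float32, 8, 3, 2)          # (W, C, N): one spatial dimension
y3 = reshape(x3, 8, 1, 3, 2)         # what the AbstractArray{T,3} method returns
@assert size(y3) == (8, 1, 3, 2)     # singleton spatial dim inserted before C

x4 = rand(Float32, 8, 8, 3, 2)       # (W, H, C, N): two spatial dimensions
y4 = reshape(x4, 8, 8, 1, 3, 2)      # what the AbstractArray{T,4} method returns
@assert size(y4) == (8, 8, 1, 3, 2)
```

Specializing on `AbstractArray{T,3}` and `AbstractArray{T,4}` lets each `reshape` take plain integer arguments instead of splatted tuple slices, which is plausibly the performance motivation here.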
16 changes: 8 additions & 8 deletions src/dim_helpers/DenseConvDims.jl
@@ -62,16 +62,16 @@ end

function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
    # First, check that channel counts are all correct:
-    @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
-    @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)))")
-    @assert w[end] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[end]) vs. $(channels_out(cdims)))")
+    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
+    @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")

    # Next, check that the spatial dimensions match up
-    @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
-    @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
-    @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
+    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
+    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
+    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")

    # Finally, check that the batch size matches
-    @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
+    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
end
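The substance of this change is replacing `end`-relative tuple indexing with the type parameter `M`. A plausible motivation (the PR discussion isn't shown here): `M` is a compile-time constant of the `NTuple{M}` type, so indices like `M-1` fold to literals. A standalone sketch with hypothetical helpers, not NNlib code:

```julia
# Both fetch the channel entry of an NHWC-style size tuple.
f_end(x::NTuple) = x[end-1]              # lastindex computed from x at run time
f_M(x::NTuple{M}) where {M} = x[M-1]     # M-1 is a literal once the tuple type is known

@assert f_end((28, 28, 3, 16)) == f_M((28, 28, 3, 16)) == 3
```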
46 changes: 20 additions & 26 deletions src/dim_helpers/DepthwiseConvDims.jl
@@ -7,19 +7,16 @@ Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily
characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
DenseConvDims primarily for channel calculation differences.
"""
-struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
+struct DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F} <: ConvDims{N,S,P,D,F}
    I::NTuple{N, Int}
-    K::NTuple{N, Int}
-    C_in::Int
-    C_mult::Int
end

# Getters for the fields
input_size(c::DepthwiseConvDims) = c.I
-kernel_size(c::DepthwiseConvDims) = c.K
-channels_in(c::DepthwiseConvDims) = c.C_in
-channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
-channel_multiplier(c::DepthwiseConvDims) = c.C_mult
+kernel_size(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = K
+channels_in(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in
+channels_out(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in * C_mult
+channel_multiplier(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_mult


# Convenience wrapper to create DepthwiseConvDims objects
@@ -37,22 +34,19 @@ function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};

    return DepthwiseConvDims{
        M - 2,
+        # Kernel spatial size
+        w_size[1:end-2],
+        # Input channels
+        x_size[end-1],
+        # Channel multiplier
+        w_size[end-1],
        stride,
        padding,
        dilation,
        flipkernel
    }(
        # Image spatial size
        x_size[1:end-2],

-        # Kernel spatial size
-        w_size[1:end-2],
-
-        # Input channels
-        x_size[end-1],
-
-        # Channel multiplier
-        w_size[end-1],
    )
end

@@ -69,22 +63,22 @@ end
function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
                           C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
                           P=padding(c), D=dilation(c), F=flipkernel(c))
-    return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
+    return DepthwiseConvDims{N, K, C_in, C_m, S, P, D, F}(I)
end

# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
    # First, check that channel counts are all correct:
-    @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
-    @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
-    @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
+    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[M-1]) vs. $(channel_multiplier(cdims)))")
+    @assert w[M] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))")

    # Next, check that the spatial dimensions match up
-    @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
-    @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
-    @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
+    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
+    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
+    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")

    # Finally, check that the batch size matches
-    @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
+    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
end
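Lifting `K`, `C_in`, and `C_mult` out of the struct fields and into type parameters means the getters above return values the compiler can treat as constants at each call site. A stripped-down sketch of the pattern with hypothetical types (not NNlib's):

```julia
struct FieldDims           # before: the value is a run-time field
    c_in::Int
end
channels(d::FieldDims) = d.c_in                    # loads a field at run time

struct ParamDims{C_in} end                         # after: the value lives in the type
channels(::ParamDims{C_in}) where {C_in} = C_in    # constant-folds at the call site

@assert channels(FieldDims(3)) == channels(ParamDims{3}()) == 3
```

The trade-off is that each distinct configuration now compiles its own specializations, which is usually acceptable for convolution shapes.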
120 changes: 79 additions & 41 deletions src/impl/conv_direct.jl
@@ -1,4 +1,5 @@
## This file contains direct Julia implementations of 2d and 3d convolutions
+using Base.Threads

# Helper functions for restricting x/w overreach
function clamp_lo(x, w)
@@ -57,50 +58,87 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
    stride_w, stride_h, stride_d = stride(cdims)
    out_width, out_height, out_depth = output_size(cdims)

-    # If we're doing crosscorr instead of conv, then don't bother to flip `w`
-    if !flipkernel(cdims)
-        w = w[end:-1:1, end:-1:1, end:-1:1, :, :]
-    end

+    # Create a method that, at compile-time, determines how we're going to index into `w`
+    kproj(k, M, cdims::ConvDims{N,S,P,D,true}) where {N, S, P, D} = k
+    kproj(k, M, cdims::ConvDims{N,S,P,D,false}) where {N, S, P, D} = M - k + 1

    # A helper function to project from output (w, h) to input (input_w, input_h)
-    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
+    project(idx, stride, pad) = (idx - 1)*stride - pad + 1

-    # explicit formulation of convolution. Oh hoisting gods, hear my plea.
-    @inbounds for batch in 1:size(x)[end],
-        c_out in 1:out_c,
-        d_idx in 1:out_depth,
-        h_idx in 1:out_height,
-        w_idx in 1:out_width
-
-        # Starting points of the window of x we're going to grab
-        x_w = project(w_idx, stride_w, pad_w_lo)
-        x_h = project(h_idx, stride_h, pad_h_lo)
-        x_d = project(d_idx, stride_d, pad_d_lo)
-
-        # Grow that starting point into ranges
-        x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1))
-        x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1))
-        x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1))
-        w_widxs = 1:kernel_w
-        w_hidxs = 1:kernel_h
-        w_didxs = 1:kernel_d
-
-        # Clamp the ranges to simulate padding
-        x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs)
-        x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width)
-        x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs)
-        x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height)
-        x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs)
-        x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth)
-
-        # Grab our slices
-        x_slice = view(x, x_widxs, x_hidxs, x_didxs, :, batch)
-        w_slice = view(w, w_widxs, w_hidxs, w_didxs, :, c_out)
-
-        # Do the dotproduct dance, then weight by alpha/beta and git 'er done
-        dotprod = sum(x_slice .* w_slice)
-        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) +
-                                               beta*y[w_idx, h_idx, d_idx, c_out, batch]
-    end
+    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding
+    padded_regions, central_region = calc_padding_regions(cdims)
+
+    # Start with the central region
+    w_region, h_region, d_region = central_region
+    @inbounds for batch in 1:size(x, 5),
+        c_out in 1:out_c,
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Since we're in the central region, we don't need to worry about clamping
+        dotprod = yT(0)
+        for c_in in 1:channels_in(cdims),
+            kd in 1:kernel_d,
+            kh in 1:kernel_h,
+            kw in 1:kernel_w
+
+            # Hoist me, you coward.
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+
+            x_val = x[x_w, x_h, x_d, c_in, batch]
+            w_val = w[kproj(kw, kernel_w, cdims),
+                      kproj(kh, kernel_h, cdims),
+                      kproj(kd, kernel_d, cdims),
+                      c_in, c_out]
+            dotprod = muladd(x_val, w_val, dotprod)
+        end
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end
+
+    # Next, do potentially-padded regions:
+    @inbounds for (w_region, h_region, d_region) in padded_regions,
+        batch in 1:size(x, 5),
+        c_out in 1:out_c,
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Probe for out-of-bounds accesses on `x` and `continue` if we hit one
+        dotprod = yT(0)
+        for c_in in 1:channels_in(cdims),
+            kd in 1:kernel_d
+
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            if x_d <= 0 || x_d > depth
+                continue
+            end
+
+            for kh in 1:kernel_h
+                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+                if x_h <= 0 || x_h > height
+                    continue
+                end
+
+                for kw in 1:kernel_w
+                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+                    if x_w <= 0 || x_w > width
+                        continue
+                    end
+
+                    x_val = x[x_w, x_h, x_d, c_in, batch]
+                    w_val = w[kproj(kw, kernel_w, cdims),
+                              kproj(kh, kernel_h, cdims),
+                              kproj(kd, kernel_d, cdims),
+                              c_in, c_out]
+                    dotprod = muladd(x_val, w_val, dotprod)
+                end
+            end
+        end
+
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end

    return y
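Two tricks above are worth spelling out. First, `kproj` replaces the old eager flip `w[end:-1:1, ...]` (an allocating copy) with an index projection chosen at compile time by dispatching on the flipkernel type parameter of `ConvDims`. A standalone sketch of that dispatch pattern, with hypothetical names:

```julia
struct Flip{F} end                        # F plays the role of ConvDims' flipkernel flag

proj(k, M, ::Flip{true})  = k             # cross-correlation: index as-is
proj(k, M, ::Flip{false}) = M - k + 1     # convolution: index from the far end

w = collect(1:5)
@assert [w[proj(k, 5, Flip{false}())] for k in 1:5] == reverse(w)  # flipped access, no copy
```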
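Second, the structural change: `calc_padding_regions` splits the output into a central region whose input windows are guaranteed in-bounds (so the hot inner loop needs no checks) and border regions that probe and `continue` past out-of-bounds taps. A 1-D sketch of the same decomposition, assuming stride 1, dilation 1, and symmetric zero padding (hypothetical code, not NNlib's API):

```julia
function conv1d_direct(x::Vector{Float64}, w::Vector{Float64}, pad::Int)
    K, W = length(w), length(x)
    out = zeros(W + 2pad - K + 1)
    project(i) = i - pad                      # output index -> first input index
    central = (pad + 1):(W + pad - K + 1)     # windows fully inside x

    for i in central                          # central region: no bounds checks
        s = 0.0
        for k in 1:K
            s += x[project(i) + k - 1] * w[k]
        end
        out[i] = s
    end

    for i in eachindex(out)                   # padded borders: probe and skip
        i in central && continue
        s = 0.0
        for k in 1:K
            xi = project(i) + k - 1
            (1 <= xi <= W) || continue        # out-of-bounds taps contribute zero
            s += x[xi] * w[k]
        end
        out[i] = s
    end
    return out
end

# conv1d_direct(randn(8), randn(3), 1): the middle six outputs take the fast path.
```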
6 changes: 3 additions & 3 deletions src/impl/conv_im2col.jl
@@ -46,7 +46,7 @@ function conv_im2col!(
    N = channels_out(cdims)
    K = prod(kernel_size(cdims))*channels_in(cdims)

-    @inbounds for batch_idx in 1:size(x,5)
+    @threads for batch_idx in 1:size(x,5)
        # We invoke `@timeit_debug` on the outside of `im2col!()` because inference
        # doesn't like us putting it on the inside.
        im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
@@ -94,7 +94,7 @@ function ∇conv_filter_im2col!(
    N = channels_out(cdims)
    K = prod(output_size(cdims))

-    @inbounds for batch_idx in 1:size(x,5)
+    @threads for batch_idx in 1:size(x,5)
        im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
        GC.@preserve col, dw, dy, begin
            col_ptr = pointer(col)
@@ -142,7 +142,7 @@ function ∇conv_data_im2col!(
    N = prod(kernel_size(cdims))*channels_in(cdims)
    K = channels_out(cdims)

-    @inbounds for batch_idx in 1:size(dx, 5)
+    @threads for batch_idx in 1:size(dx, 5)
        GC.@preserve col, w, dy, begin
            dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
            w_ptr = pointer(w)
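Switching the batch loop from `@inbounds for` to `@threads for` parallelizes im2col and the per-batch GEMMs across batches. The handling of the `col` buffer lies outside these hunks; however it is managed, the pattern needs per-thread or per-iteration scratch space to stay race-free. A minimal sketch under that assumption (hypothetical code, not NNlib's):

```julia
using Base.Threads

function sum_batches_threaded(x::Array{Float32,3})
    y = zeros(Float32, size(x, 3))
    @threads for b in 1:size(x, 3)
        # Per-iteration scratch: simple and race-free, at the cost of one
        # allocation per batch. (A per-thread buffer pool would avoid it.)
        col = Vector{Float32}(undef, size(x, 1) * size(x, 2))
        copyto!(col, view(x, :, :, b))   # stand-in for im2col!(col, x_b, cdims)
        y[b] = sum(col)                  # stand-in for the per-batch GEMM
    end
    return y
end

# Run with e.g. JULIA_NUM_THREADS=4: sum_batches_threaded(rand(Float32, 4, 4, 8))
```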