Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reflect mode parameter from NNlib #53

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
padtuple(x::Tuple,p::Tuple) = p
padtuple(x::AbstractArray,p) = padtuple(size(x),p)

function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel=true) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
if flipkernel
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
x, w, pad = pad_, stride = stride_, dilation = dilation)
x, w, pad = pad_, stride = stride_, dilation = dilation, flipkernel=flipkernel)
else
crosscor!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't make sense to do an if/else here; you pass flipkernel through so removing it would behave the same way.

It would however be good to have an out-of-place crosscor function that calls crosscor!.

x, w, pad = pad_, stride = stride_, dilation = dilation, flipkernel=flipkernel)
end
end

∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
Expand All @@ -39,9 +44,10 @@ end
# N-D dispatch

function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
pad = 0, stride = 1, dilation = 1, flipkernel=true) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1),
flipkernel= flipkernel)
return y
end

Expand All @@ -62,8 +68,12 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
end

conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
pad = 0, stride = 1, dilation = 1, flipkernel=true) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)

crosscor!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cross correlation should have flipkernel=false, or at least be different from conv!, no? I also think it'd be better to keep this generic to AbstractArrays and just forward whatever the arguments are to conv!.

pad = 0, stride = 1, dilation = 1, flipkernel=true) where T =
conv!(y, x, w, pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)

∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
Expand All @@ -74,8 +84,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
pad = 0, stride = 1, dilation = 1, flipkernel=true) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)

∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
Expand All @@ -91,14 +101,14 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
end

function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
function depthwiseconv(x::A, w::A; pad = 0, stride = 1, flipkernel = false) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_, flipkernel = flipkernel)
end

depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
pad = 0, stride = 1, flipkernel = true) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, flipkernel = flipkernel)

∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwiseconv_data!(zeros(x), dy, x, w; pad = pad, stride = stride)
Expand Down
78 changes: 38 additions & 40 deletions src/impl/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end

function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int, height::Int, channels::Int,
kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int,
dil_w::Int, dil_h::Int, mode::Int) where T
dil_w::Int, dil_h::Int, flipkernel::Bool) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
Expand All @@ -23,7 +23,7 @@ function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int
w_offset = (c - 1) % kernel_w
h_offset = div(c - 1, kernel_w) % kernel_h
c_im = div(c - 1, kernel_h * kernel_w)
if mode == 0
if flipkernel
w_offset = kernel_w - 1 - w_offset
h_offset = kernel_h - 1 - h_offset
end
Expand All @@ -44,7 +44,7 @@ end

function col2im_2d!(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int, height::Int,
channels::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int,
stride_h::Int, dil_w::Int, dil_h::Int, mode::Int) where T
stride_h::Int, dil_w::Int, dil_h::Int, flipkernel::Bool) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
Expand All @@ -56,7 +56,7 @@ function col2im_2d!(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int
w_offset = (c - 1) % kernel_w
h_offset = div(c - 1, kernel_w) % kernel_h
c_im = div(c - 1, kernel_h * kernel_w)
if mode == 0
if flipkernel
w_offset = kernel_w - 1 - w_offset
h_offset = kernel_h - 1 - h_offset
end
Expand All @@ -73,7 +73,7 @@ end

function im2col_3d!(img::AbstractArray{T,4}, col::AbstractArray{T,2}, width::Int, height::Int, depth::Int,
channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, pad_w::Int, pad_h::Int, pad_d::Int,
stride_w::Int, stride_h::Int, stride_d::Int, dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T
stride_w::Int, stride_h::Int, stride_d::Int, dil_w::Int, dil_h::Int, dil_d::Int, flipkernel::Bool) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
Expand All @@ -87,7 +87,7 @@ function im2col_3d!(img::AbstractArray{T,4}, col::AbstractArray{T,2}, width::Int
h_offset = div(c - 1, kernel_w) % kernel_h
d_offset = div(c - 1, kernel_w * kernel_h) % kernel_d
c_im = div(c - 1, kernel_w * kernel_h * kernel_d)
if mode == 0
if flipkernel
w_offset = kernel_w - 1 - w_offset
h_offset = kernel_h - 1 - h_offset
d_offset = kernel_d - 1 - d_offset
Expand All @@ -110,7 +110,7 @@ end
function col2im_3d!(col::AbstractArray{T,2}, img::AbstractArray{T,4}, width::Int, height::Int,
depth::Int, channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int,
pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int,
dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T
dil_w::Int, dil_h::Int, dil_d::Int, flipkernel::Bool) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
Expand All @@ -125,7 +125,7 @@ function col2im_3d!(col::AbstractArray{T,2}, img::AbstractArray{T,4}, width::Int
d_offset = div(c - 1, kernel_w * kernel_h) % kernel_d
c_im = div(c - 1, kernel_h * kernel_w * kernel_d)

if mode == 0
if flipkernel
w_offset = kernel_w - 1 - w_offset
h_offset = kernel_h - 1 - h_offset
d_offset = kernel_d - 1 - d_offset
Expand Down Expand Up @@ -182,7 +182,7 @@ function im2col_dims(w::NTuple{4, Int}, y)
end

function depthwiseconv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
padding = 0, stride = 1, mode = 1, alpha = T(1)) where T
padding = 0, stride = 1, flipkernel = true, alpha = T(1)) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
@assert Cx == Cw DimensionMismatch()
Expand All @@ -195,7 +195,7 @@ function depthwiseconv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstr
M,N,K,Y = Wy*Hy,Cm,Ww*Hw,Wy*Hy*Cm
yidx = 1
@inbounds for i in 1:Nx
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode)
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, flipkernel)
@inbounds for j in 1:Cx
gemm!('N','N',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(w,(j-1)*K*N+1),T(0),pointer(y,yidx))
yidx += Y
Expand All @@ -205,7 +205,7 @@ function depthwiseconv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstr
end

function depthwiseconv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
padding=0, stride=1, mode=0, alpha=1) where T
padding=0, stride=1, flipkernel = true, alpha=1) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
@assert Cx == Cw DimensionMismatch()
Expand All @@ -220,7 +220,7 @@ function depthwiseconv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4},
alpha,beta = T(alpha),T(1)
dyidx = 1
@inbounds for i in 1:Nx
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode)
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, flipkernel)
dwidx = 1
@inbounds for j in 1:Cx
gemm!('T','T',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(dy,dyidx+(j-1)*K*N),beta,pointer(dw,dwidx))
Expand All @@ -232,7 +232,7 @@ function depthwiseconv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4},
end

function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
padding=0, stride=1, mode=0, alpha=1) where T
padding=0, stride=1, flipkernel = true, alpha=1) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
@assert Cx == Cw DimensionMismatch()
Expand All @@ -250,15 +250,14 @@ function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4},
@inbounds for j in 1:Cx
gemm!('N','T',M,N,K,alpha,pointer(dy,dyidx+(j-1)*K*M),pointer(w,(j-1)*K*N+1),beta,pointer(x2,(j-1)*M*N+1))
end
col2im2d!(dims_w,dx,x2,i,p1,p2,s1,s2,mode)
col2im2d!(dims_w,dx,x2,i,p1,p2,s1,s2,flipkernel)
dyidx += Y
end
return dx
end

function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T
if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
padding=0, stride=1, dilation=1, flipkernel=true, alpha=T(1)) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = size(w)
if Cx!=C1; throw(DimensionMismatch()); end
Expand All @@ -271,15 +270,15 @@ function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{
M,N,K,Y = Wy*Hy,Cy,Ww*Hw*Cx,Wy*Hy*Cy
yidx = 1
@inbounds for n in 1:Nx
im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, mode)
im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, flipkernel)
gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(w),T(0),pointer(y,yidx))
yidx += Y
end
return y
end

function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
padding=0, stride=1, dilation=1, mode=0, alpha=1) where T
padding=0, stride=1, dilation=1, flipkernel=true, alpha=1) where T
# dw = x'*dy
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = size(w)
Expand All @@ -296,15 +295,15 @@ function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstra
(d1,d2) = psize(dilation,x)
dyi = 1
@inbounds for n in 1:Nx
im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, mode)
im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, flipkernel)
gemm!('T','N',M,N,K,alpha,pointer(x2),pointer(dy,dyi),beta,pointer(dw))
dyi += Y
end
return dw
end

function conv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
padding=0, stride=1, dilation=1, mode=0, alpha=1) where T
padding=0, stride=1, dilation=1, flipkernel=true, alpha=1) where T
# dx = dy*w'
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = size(w)
Expand All @@ -322,53 +321,52 @@ function conv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstra
dyi = 1
@inbounds for n in 1:Nx
gemm!('N','T',M,N,K,alpha,pointer(dy,dyi),pointer(w),beta,pointer(x2))
col2im2d!(w,dx,x2,n,p1,p2,s1,s2,d1,d2,mode)
col2im2d!(w,dx,x2,n,p1,p2,s1,s2,d1,d2,flipkernel)
dyi += Y
end
return dx
end

function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, flipkernel::Bool) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = w
xn = x[:, :, :, n]
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,flipkernel)
return x2
end

function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, flipkernel::Bool) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = size(w)
xn = x[:, :, :, n]
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,mode)
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,flipkernel)
return x2
end

function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, flipkernel::Bool) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = w
xn = x[:, :, :, n]
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,flipkernel)
x[:, :, :, n] = xn
return x
end

function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, flipkernel::Bool) where T
Wx,Hx,Cx,Nx = size(x)
Ww,Hw,C1,C2 = size(w)
xn = x[:, :, :, n]
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,mode)
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,flipkernel)
x[:, :, :, n] = xn
return x
end

function conv3d!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
padding=0, stride=1, dilation = 1, mode=0, alpha=T(1)) where T
if mode != 0 && mode != 1; throw(ArgumentError("conv3d only supports mode=0 or 1.")); end
padding=0, stride=1, dilation = 1, flipkernel=true, alpha=T(1)) where T
Wx,Hx,Dx,Cx,Nx = size(x)
Ww,Hw,Dw,C1,C2 = size(w)
if Cx!=C1; throw(DimensionMismatch()); end
Expand All @@ -383,15 +381,15 @@ function conv3d!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{
yidx = 1
W = reshape(w, (size(w, 1),:,C1,C2))
@inbounds for n in 1:Nx
im2col3d!(w, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, mode)
im2col3d!(w, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, flipkernel)
gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(W),T(0),pointer(y,yidx))
yidx += Y
end
return y
end

function conv3d_grad_w!(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, dy::AbstractArray{T,5};
padding=0, stride=1, dilation = 1, mode=0, alpha=1) where T
padding=0, stride=1, dilation = 1, flipkernel=true, alpha=1) where T
# dw = x'*dy
Wx,Hx,Dx,Cx,Nx = size(x)
Ww,Hw,Dw,C1,C2 = size(w)
Expand All @@ -408,15 +406,15 @@ function conv3d_grad_w!(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abstra
(d1,d2,d3) = psize(dilation,x)
dyi = 1
@inbounds for n in 1:Nx
im2col3d!(w, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, mode)
im2col3d!(w, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, flipkernel)
gemm!('T','N',M,N,K,alpha,pointer(x2),pointer(dy,dyi),beta,pointer(dw))
dyi += Y
end
return dw
end

function conv3d_grad_x!(dx::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, dy::AbstractArray{T,5};
padding=0, stride=1, dilation = 1, mode=0, alpha=1) where T
padding=0, stride=1, dilation = 1, flipkernel=true, alpha=1) where T
# dx = dy*w'
Wx,Hx,Dx,Cx,Nx = size(x)
Ww,Hw,Dw,C1,C2 = size(w)
Expand All @@ -434,29 +432,29 @@ function conv3d_grad_x!(dx::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abstra
dyi = 1
@inbounds for n in 1:Nx
gemm!('N','T',M,N,K,alpha,pointer(dy,dyi),pointer(w),beta,pointer(x2))
col2im3d!(w,dx,x2,n,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
col2im3d!(w,dx,x2,n,p1,p2,p3,s1,s2,s3,d1,d2,d3,flipkernel)
dyi += Y
end
return dx
end

function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, p3::Int, s1::Int, s2::Int,
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
s3::Int, d1::Int, d2::Int, d3::Int, flipkernel::Bool) where T
Wx,Hx,Dx,Cx,Nx = size(x)
Ww,Hw,Dw,C1,C2 = size(w)
xn = x[:, :, :, :, n]
im2col_3d!(xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
im2col_3d!(xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,flipkernel)
return x2
end

function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, p3::Int, s1::Int, s2::Int,
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
s3::Int, d1::Int, d2::Int, d3::Int, flipkernel::Bool) where T
Wx,Hx,Dx,Cx,Nx = size(x)
Ww,Hw,Dw,C1,C2 = size(w)
xn = x[:, :, :, :, n]
col2im_3d!(x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
col2im_3d!(x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,flipkernel)
x[:, :, :, :, n] = xn
return x
end