Add fold.jl #303

Closed
wants to merge 2 commits into from
1 change: 1 addition & 0 deletions src/NNlib.jl
@@ -32,6 +32,7 @@ include("batched/batchedmul.jl")
include("gemm.jl")
include("conv.jl")
include("conv_bias_act.jl")
include("fold.jl")
include("pooling.jl")
include("padding.jl")
include("upsample.jl")
70 changes: 70 additions & 0 deletions src/fold.jl
@@ -0,0 +1,70 @@
export unfold, fold
Member

Let's not export them; the names would clutter the namespace very quickly.
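For context on the suggestion above, here is a minimal sketch of how unexported names stay reachable through the module prefix. The module `Toy` and both functions are hypothetical stand-ins, not part of NNlib or of this PR.

```julia
# Hypothetical stand-in module; `Toy` and its functions are illustrative only.
module Toy
export relu_like                 # exported: visible after `using .Toy`
relu_like(x) = max(x, zero(x))
unfold_like(x) = x               # not exported: needs the `Toy.` prefix
end

using .Toy

relu_like(-1)        # 0, callable without a prefix
Toy.unfold_like(3)   # 3, callable only with the module prefix
```

Under this arrangement, callers would write the qualified name for the unexported function, so short generic names do not leak into the caller's namespace.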


"""
    unfold(X, W; stride=1, padding=0, dilation=1)

Extracts sliding local blocks from a batched input tensor. `X` is the input array, 3d/4d/5d for a
1d/2d/3d image, of size `(spatial_dims... , channels, batch_size)`. `W` is the size of the kernel,
in the format `(spatial_dims... , channels)`. The output has size
`(L, channels*kernel_w*kernel_h*kernel_d, batch_size)`, where `L` is the total number of blocks.
"""
function unfold(X::AbstractArray{T,M} where T, w_dim::NTuple{K}; stride=1, padding=0, dilation=1) where M where K
    x_dim = size(X)
    if ndims(X) > 5 || ndims(X) < 3
        throw(DimensionMismatch("X and W must be 3d/4d/5d for 1d/2d/3d image. got $(ndims(X))d input"))
    end

    if ndims(X)-2 != length(w_dim)-1
        throw(DimensionMismatch("spatial dimensions of image and kernel must be equal, got $(ndims(X)-2),$(length(w_dim)-1)"))
    end

    # reassign x_dim after converting it to a 3d image type input
    x_dim = ( x_dim[1:end-2]... , fill(1,5-ndims(X))... , x_dim[end-1:end]... )
    # w_dim must be in following format: (spatial_dims..., channels_in, channels_out)
    w_dim = ( w_dim[1:end-1]... , fill(1,4-length(w_dim))... , w_dim[end], w_dim[end] )
    X = reshape(X, x_dim)

    # Make DenseConvDims object
    cdims = DenseConvDims(x_dim, w_dim; stride=stride, padding=padding, dilation=dilation)

    # Calculate the total number of sliding blocks
    col_dim = (im2col_dims(cdims))[1:2] # im2col_dims() returns (col_dim_x, col_dim_y, thread_num)
    col = fill(0., col_dim[1],col_dim[2],x_dim[end]) # x_dim[end] is number of batches

    # Iterate through all batches
    @views for i = 1:x_dim[end]
        im2col!(col[:,:,i], X[:,:,:,:,i], cdims)
    end
    return col
end
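To make the block count `L` from the docstring concrete, here is a minimal pure-Julia sketch for the 1d, single-channel case. It does not call NNlib's `im2col!`; the helper names `num_blocks` and `unfold1d` are hypothetical, and the output-length formula is the usual convolution arithmetic, which is assumed (not verified here) to match what `DenseConvDims` computes.

```julia
# Number of sliding blocks along one spatial dimension (standard convolution
# output-size arithmetic; an assumption, not taken from the PR).
num_blocks(n, k; stride=1, padding=0, dilation=1) =
    (n + 2padding - dilation * (k - 1) - 1) ÷ stride + 1

# Hypothetical 1d, single-channel unfold: row l holds the k samples of block l.
function unfold1d(x::AbstractVector, k::Integer; stride=1, padding=0, dilation=1)
    # Zero-pad both ends of the signal.
    xp = vcat(zeros(eltype(x), padding), x, zeros(eltype(x), padding))
    L = num_blocks(length(x), k; stride=stride, padding=padding, dilation=dilation)
    col = zeros(eltype(x), L, k)
    for l in 1:L, j in 1:k
        col[l, j] = xp[(l - 1) * stride + (j - 1) * dilation + 1]
    end
    return col
end

col = unfold1d([1, 2, 3, 4, 5], 3; stride=2)   # rows: [1 2 3] and [3 4 5]
```

With five samples, a kernel of width 3, and stride 2, the formula gives `L = 2`, which matches the two rows of `col`.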

"""
    fold(col, out_dim, W; stride=1, padding=0, dilation=1)

Does the opposite of `unfold()`: combines an array of sliding local blocks into a large containing
tensor. `col` is a 3d array of shape `(L, channels*kernel_w*kernel_h*kernel_d, batch_size)`, where
`L` is the total number of blocks. `out_dim` is the spatial dimensions of the required image. `W` is
the spatial dimensions of the kernel.
"""
function fold(col::AbstractArray{T,3} where T, out_dim::NTuple{M}, w_dim::NTuple{M}; stride=1, padding=0, dilation=1) where M
    # Validate input
    if length(out_dim) > 3
        throw(DimensionMismatch("output dimensions cannot be greater than 3, got $(length(out_dim))"))
    end

    # Create DenseConvDims object
    col_dim = size(col)
    channels = col_dim[2]÷prod(w_dim)
    x_dim = (out_dim... , fill(3-length(out_dim))... , channels,col_dim[3])
Member

The calls to fill seem suspect. Did you need ntuple there?

Author

I needed an ntuple with the dimensions of out_dim, padded with dummy dimensions to make it 3d. I think it should be fill(1,3-length(out_dim)). It's okay to use fill here, right?

Member

Use insert_singleton_dimension instead. Or a reshape. fill is incorrect here.
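A small sketch of why the single-argument `fill` call misbehaves, and of the `ntuple` alternative raised in the thread. The helper names `splat_fill` and `pad_to_3d` are hypothetical and exist only for this illustration; they are not part of the PR or of NNlib.

```julia
# `fill(x)` with no dimensions builds a 0-dimensional array holding `x`,
# so splatting it yields the single value `x`, not `x` copies of 1.
splat_fill(out_dim) = (out_dim..., fill(3 - length(out_dim))...)

# Hypothetical helper: pad a tuple of spatial sizes with trailing 1s up to 3d.
pad_to_3d(dims::NTuple{N,Int}) where {N} = (dims..., ntuple(_ -> 1, 3 - N)...)

splat_fill((7,))       # (7, 2): wrong for a 1d image
splat_fill((7, 7))     # (7, 7, 1): right only by coincidence
splat_fill((7, 7, 7))  # (7, 7, 7, 0): an extra zero-sized dimension
pad_to_3d((7,))        # (7, 1, 1)
pad_to_3d((7, 7, 7))   # (7, 7, 7)
```

Because `N` is a type parameter, the `ntuple` version could also let the compiler know the tuple length, which a splatted `fill(1, n)` array would not.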

    w_dim = (w_dim... , fill(3-length(w_dim))... , channels,channels)
    cdims = DenseConvDims(x_dim,w_dim; stride=stride, padding=padding, dilation=dilation)

    img = fill(0., x_dim)

    # Iterate through all batches
    @views for i = 1:x_dim[end]
        col2im!(img[:,:,:,:,i], col[:,:,i], cdims)
    end

    return reshape(img, (out_dim... , channels,col_dim[3]))
end
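The docstring describes `fold` as the opposite of `unfold`. A minimal 1d, single-channel sketch of that operation is below; it assumes overlapping entries are accumulated (as col2im-style routines usually do), and the helper name `fold1d` is hypothetical rather than anything defined in this PR.

```julia
# Hypothetical 1d, single-channel fold: scatter-add each block back into place.
# `col` has one block per row; positions covered by several blocks are summed.
function fold1d(col::AbstractMatrix, n::Integer; stride=1, dilation=1)
    L, k = size(col)
    x = zeros(eltype(col), n)
    for l in 1:L, j in 1:k
        x[(l - 1) * stride + (j - 1) * dilation + 1] += col[l, j]
    end
    return x
end

col = [1 2 3; 3 4 5]            # the two stride-2 blocks of [1, 2, 3, 4, 5]
x = fold1d(col, 5; stride=2)    # [1, 2, 6, 4, 5]: sample 3 is in both blocks
```

Under that assumption, the round trip reproduces the input exactly only when blocks do not overlap, for example when the stride equals the kernel width.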
Member

Definitely don't want to see this. Possibly the editor is using file endings different from Linux. I'd check that.