Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add adaptive pool #1239

Merged
merged 16 commits into from
Jun 30, 2020
2 changes: 2 additions & 0 deletions docs/src/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ These layers are used to build convolutional neural networks (CNNs).

```@docs
Conv
AdaptiveMaxPool
MaxPool
GlobalMaxPool
AdaptiveMeanPool
MeanPool
GlobalMeanPool
DepthwiseConv
Expand Down
7 changes: 4 additions & 3 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTranspose,
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool,
MeanPool, flatten, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm,
InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64,
testmode!, trainmode!

include("optimise/Optimise.jl")
using .Optimise
Expand Down
54 changes: 54 additions & 0 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,60 @@ end
outdims(l::CrossCor, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
AdaptiveMaxPool(out)
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved

Adaptive max pooling layer. `out` is the size of the dimension of the output.
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved

The output dimension is fixed irrespective of the size of the input.
"""
struct AdaptiveMaxPool{S, O}
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved
out::NTuple{O, Int}
AdaptiveMaxPool(out::NTuple{O, Int}) where O = new{O + 2, O}(out)
end

function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T}
sz = size(x)
insize = sz[1:end-2]
outsize = a.out
stride = insize ./ outsize
k = insize .- (outsize .- 1) .* stride
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved
pad = 0
pdims = PoolDims(x, k; padding=pad, stride=stride)
return maxpool(x, pdims)
end

function Base.show(io::IO, a::AdaptiveMaxPool)
print(io, "AdaptiveMaxPool(", a.out, ")")
end

"""
AdaptiveMeanPool(out)

Adaptive mean pooling layer. `out` is the size of the dimension of the output.

The output dimension is fixed irrespective of the size of the input.
"""
struct AdaptiveMeanPool{S, O}
out::NTuple{O, Int}
AdaptiveMeanPool(out::NTuple{O, Int}) where O = new{O + 2, O}(out)
end

function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T}
sz = size(x)
insize = sz[1:end-2]
outsize = a.out
stride = insize ./ outsize
k = insize .- (outsize .- 1) .* stride
pad = 0
pdims = PoolDims(x, k; padding=pad, stride=stride)
return meanpool(x, pdims)
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved
end

function Base.show(io::IO, a::AdaptiveMeanPool)
print(io, "AdaptiveMeanPool(", a.out, ")")
end

"""
GlobalMaxPool()

Expand Down
4 changes: 4 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ using Flux: gradient

@testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2)
amp = AdaptiveMaxPool((5,5))
@test amp(x) == maxpool(x, PoolDims(x, 2))
amp = AdaptiveMeanPool((5,5))
@test amp(x) == meanpool(x, PoolDims(x, 2))
dnabanita7 marked this conversation as resolved.
Show resolved Hide resolved
gmp = GlobalMaxPool()
@test size(gmp(x)) == (1, 1, 3, 2)
gmp = GlobalMeanPool()
Expand Down