add Upsample and PixelShuffle layers #1468

Merged: 19 commits, Feb 11, 2021
Changes from 13 commits
8 changes: 4 additions & 4 deletions Manifest.toml
@@ -93,9 +93,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"

[[DataAPI]]
git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f"
git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.5.0"
version = "1.5.1"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -236,9 +236,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "573cc0d31f9697b9d2b060130a7a3c05a4f36b78"
git-tree-sha1 = "df42d0816edfc24f5b82a728f46381613c4dff79"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.12"
version = "0.7.14"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
1 change: 1 addition & 0 deletions NEWS.md
@@ -12,6 +12,7 @@
* Moved GPU CI to use buildkite instead of GitLab
* New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks.
* Feature additions and bug fixes for BatchNorm, LayerNorm, InstanceNorm, and GroupNorm [normalization layers](https://github.com/FluxML/Flux.jl/pull/1397)
* Added [Upsample and PixelShuffle layers](https://github.com/FluxML/Flux.jl/pull/1468)

## v0.11.2

2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ Colors = "0.12"
Functors = "0.1, 0.2"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.10"
NNlib = "0.7.14"
Reexport = "0.2, 1.0"
StatsBase = "0.33"
ZipFile = "0.9"
7 changes: 7 additions & 0 deletions docs/src/models/layers.md
@@ -29,6 +29,13 @@ Flux.convfilter
Flux.depthwiseconvfilter
```

## Upsampling Layers

```@docs
Upsample
PixelShuffle
```

## Recurrent Layers

Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
8 changes: 8 additions & 0 deletions docs/src/models/nnlib.md
@@ -51,6 +51,14 @@ NNlib.conv
NNlib.depthwiseconv
```

## Upsampling

```@docs
NNlib.upsample_nearest
NNlib.upsample_bilinear
NNlib.pixel_shuffle
```
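
For orientation, a small sketch of calling these NNlib functions directly; the argument forms mirror how the new `Upsample` and `PixelShuffle` layers call them, and the expected sizes follow the tests added in this PR:

```julia
using NNlib

x = rand(Float32, 3, 4, 2, 3)  # (width, height, channels, batch)

size(NNlib.upsample_nearest(x, (2, 3)))                   # (6, 12, 2, 3)
size(NNlib.upsample_bilinear(x; size = (4, 6)))           # (4, 6, 2, 3)
size(NNlib.pixel_shuffle(rand(Float32, 3, 4, 18, 3), 3))  # (9, 12, 2, 3)
```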

## Batched Operations

```@docs
2 changes: 2 additions & 0 deletions src/Flux.jl
@@ -16,6 +16,7 @@ export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
Upsample, PixelShuffle,
params, fmap, cpu, gpu, f32, f64,
testmode!, trainmode!

@@ -42,6 +43,7 @@ include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")

include("outputsize.jl")

81 changes: 81 additions & 0 deletions src/layers/upsample.jl
@@ -0,0 +1,81 @@
"""
    Upsample(mode = :nearest; scale = nothing, size = nothing)

An upsampling layer.

`scale` is a number or a tuple of numbers
representing the output rescaling factor along each spatial dimension.
For integer `scale`, all but the last 2 dimensions (channel and batch)
will be rescaled by the same factor.

It is also possible to directly specify the output spatial `size`,
as an alternative to using `scale`.

Currently supported upsampling `mode`s and the corresponding NNlib methods are:
- `:nearest` -> [`NNlib.upsample_nearest`](@ref)
- `:bilinear` -> [`NNlib.upsample_bilinear`](@ref)

# Examples

```julia-repl
julia> m = Upsample(scale = (2, 3))
Upsample(:nearest, scale=(2, 3))

julia> m(ones(2, 2, 1, 1)) |> size
(4, 6, 1, 1)

julia> m = Upsample(:bilinear, size = (4, 5))
Upsample(:bilinear, size=(4, 5))

julia> m(ones(2, 2, 1, 1)) |> size
(4, 5, 1, 1)
"""
struct Upsample{Mode,S,T}
  scale::S
  size::T
end

function Upsample(mode = :nearest; scale = nothing, size = nothing)
  mode in [:nearest, :bilinear] ||
    throw(ArgumentError("mode=:$mode is not supported."))
  if !((scale === nothing) ⊻ (size === nothing))
    throw(ArgumentError("Exactly one of scale or size must be specified."))
  end
  return Upsample{mode,typeof(scale),typeof(size)}(scale, size)
end

(m::Upsample{:nearest})(x::AbstractArray) =
  NNlib.upsample_nearest(x, m.scale)
function (m::Upsample{:nearest, Int})(x::AbstractArray{T, N}) where {T, N}
  NNlib.upsample_nearest(x, ntuple(i -> m.scale, N-2))
end
(m::Upsample{:nearest, Nothing})(x::AbstractArray) =
  NNlib.upsample_nearest(x; size=m.size)

(m::Upsample{:bilinear})(x::AbstractArray) =
  NNlib.upsample_bilinear(x, m.scale)
(m::Upsample{:bilinear, Nothing})(x::AbstractArray) =
  NNlib.upsample_bilinear(x; size=m.size)

function Base.show(io::IO, u::Upsample{mode}) where {mode}
  print(io, "Upsample(")
  print(io, ":", mode)
  u.scale !== nothing && print(io, ", scale=", u.scale)
  u.size !== nothing && print(io, ", size=", u.size)
  print(io, ")")
end

"""
    PixelShuffle(r::Int)

Pixel shuffling layer with upscale factor `r`.

See [`NNlib.pixel_shuffle`](@ref).
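
# Examples

A size-only sketch (shapes follow the tests accompanying this layer):

```julia-repl
julia> m = PixelShuffle(2);

julia> m(rand(Float32, 3, 18, 3)) |> size   # one spatial dimension
(6, 9, 3)

julia> PixelShuffle(3)(rand(Float32, 3, 4, 18, 3)) |> size   # two spatial dimensions
(9, 12, 2, 3)
```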
"""
struct PixelShuffle
  r::Int
end

(m::PixelShuffle)(x) = NNlib.pixel_shuffle(x, m.r)
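
For a sense of how the new layers fit together with existing ones, here is an illustrative sketch (not part of this diff; the channel counts and input sizes are arbitrary):

```julia
using Flux

# Toy 2x "super-resolution" head: expand channels with a Conv,
# then trade channels for spatial resolution with PixelShuffle.
head = Chain(
    Conv((3, 3), 3 => 3 * 2^2, relu; pad = 1),  # (W, H, 3, N) -> (W, H, 12, N)
    PixelShuffle(2),                            # (W, H, 12, N) -> (2W, 2H, 3, N)
)

x = rand(Float32, 16, 16, 3, 1);
size(head(x))                                   # (32, 32, 3, 1)

# Upsample can be driven either by a scale factor or a target size:
size(Upsample(:bilinear, scale = 2)(x))         # (32, 32, 3, 1)
size(Upsample(:nearest, size = (24, 24))(x))    # (24, 24, 3, 1)
```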

2 changes: 1 addition & 1 deletion src/outputsize.jl
@@ -35,7 +35,7 @@ Base.isless(::Nil, ::Number) = true
Base.isless(::Number, ::Nil) = true

Base.isnan(::Nil) = false

Base.isfinite(::Nil) = true
Base.typemin(::Type{Nil}) = nil
Base.typemax(::Type{Nil}) = nil

11 changes: 10 additions & 1 deletion test/cuda/layers.jl
@@ -79,6 +79,15 @@ gpu_gradtest("GroupNorm 3d", groupnorm, rand(Float32, 8, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 2d", groupnorm, rand(Float32, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 1d", groupnorm, rand(Float32, 8, 3, 12, 4), 12, 3, setmode=true)

upsample = [x -> Upsample(scale=x)]
gpu_gradtest("Upsample 2d", upsample, rand(Float32, 3, 4, 2, 3), (2,2))
gpu_gradtest("Upsample 1d", upsample, rand(Float32, 3, 4, 2, 3), (2,))

pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)


CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
@testset "function layers" begin
x = rand(Float32, 3,3)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
@@ -168,4 +177,4 @@ end
@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∉ gs.params
end
end
63 changes: 63 additions & 0 deletions test/layers/upsample.jl
@@ -0,0 +1,63 @@
@testset "upsample bilinear" begin
m = Upsample(:bilinear, scale=(2, 3))
x = rand(Float32, 3, 4, 2, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 12, 2, 3)

m = Upsample(:bilinear, scale=3)
x = rand(Float32, 3, 4, 2, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (9, 12, 2, 3)

m = Upsample(:bilinear, size=(4, 6))
x = rand(Float32, 3, 4, 2, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (4, 6, 2, 3)
end

@testset "upsample nearest" begin
x = rand(Float32, 3, 2, 3)
m = Upsample(:nearest, scale=(2,))
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (6, 2, 3)

x = rand(Float32, 3, 4, 2, 3)

m = Upsample(:nearest, scale=(2, 3))
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 12, 2, 3)

m = Upsample(:nearest, scale=(2,))
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 4, 2, 3)

m = Upsample(:nearest, scale=2)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 8, 2, 3)

m = Upsample(:nearest, size=(6,8))
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 8, 2, 3)
end

@testset "PixelShuffle" begin
m = PixelShuffle(2)
x = rand(Float32, 3, 18, 3)
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (6, 9, 3)

m = PixelShuffle(3)
x = rand(Float32, 3, 4, 18, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (9, 12, 2, 3)
end
8 changes: 3 additions & 5 deletions test/outputsize.jl
@@ -146,9 +146,7 @@ end
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)

if VERSION >= v"1.1"
m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
end
m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -34,6 +34,7 @@ end
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
include("layers/upsample.jl")
end

@testset "outputsize" begin
Expand Down