Skip to content

Commit

Permalink
Merge pull request #320 from maxfreu/upsample-linear
Browse files Browse the repository at this point in the history
introduce linear upsampling
  • Loading branch information
CarloLucibello authored Jun 1, 2021
2 parents 1dbdcef + b5abee5 commit 4411c86
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 2 deletions.
115 changes: 115 additions & 0 deletions src/upsample.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export upsample_nearest, ∇upsample_nearest,
upsample_linear, ∇upsample_linear,
upsample_bilinear, ∇upsample_bilinear,
upsample_trilinear, ∇upsample_trilinear,
pixel_shuffle
Expand Down Expand Up @@ -96,6 +97,120 @@ end
return input_index0, input_index1, lambda0, lambda1
end

###########
# linear
###########
"""
upsample_linear(x::AbstractArray{T,3}, scale::Real)
upsample_linear(x::AbstractArray{T,3}; size::Integer)
Upsamples the first dimension of the array `x` by the upsample provided `scale`,
using linear interpolation. As an alternative to using `scale`, the resulting array `size`
can be directly specified with a keyword argument.
The size of the output is equal to
`(scale*S1, S2, S3)`, where `S1, S2, S3 = size(x)`.
"""
function upsample_linear(x::AbstractArray{<:Any,3}, scale::Real)
outsize = floor(Int, scale * Base.size(x)[1])
return upsample_linear(x; size=outsize)
end

function upsample_linear(x::AbstractArray{T,3}; size::Integer) where T
w,c,n = Base.size(x)
if w == size
return x
end
y = similar(x, T, size, c, n)
return upsample_linear_wcn!(y, x)
end

function upsample_linear(x::AbstractArray{T,3}; size::Integer) where T<:Integer
y = float.(x)
res = upsample_linear(y; size=size)
return round.(T, res)
end

function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, channels, batches = size(input)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, _, _ = size(output)
output_slice_size = out_w

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))

@inline idx(c, w) = c * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + ow + 1
output[output_offset] = (w0lambda * input[idx(c, iw0)] + # w0 * i00
w1lambda * input[idx(c, iw1)]) # w1 * i01
end
end
return output
end

"""
∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer) where T
# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Size of the image upsampled in the first place
# Outputs
- `dx`: Downsampled version of `Δ`
"""
function ∇upsample_linear::AbstractArray{T,3}; size::Integer) where T
w, c, n = Base.size(Δ)
out_w = size
if w == out_w
return Δ
end
dx = zero(similar(Δ, T, out_w, c, n))
return ∇upsample_linear_wcn!(dx, Δ)
end

function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, channels, batches = size(dx)

# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, _, _ = size(Δ)
output_slice_size = out_w

width_scale = T((in_w - 1) // (out_w - 1))

@inline idx(c, w) = c * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + ow + 1
Δ_value = Δ[output_offset]
dx[idx(c, iw0)] += w0lambda * Δ_value # i00
dx[idx(c, iw1)] += w1lambda * Δ_value # i01
end
end
return dx
end

function rrule(::typeof(upsample_linear), x; size)
Ω = upsample_linear(x; size=size)
function upsample_linear_pullback(Δ)
(NO_FIELDS, ∇upsample_linear(Δ; size=Base.size(x,1)))
end
return Ω, upsample_linear_pullback
end

###########
# bilinear
###########
"""
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})
Expand Down
17 changes: 15 additions & 2 deletions test/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,20 @@
@test_throws ArgumentError upsample_nearest(x, size=(3,4))
end

@testset "upsample_bilinear 2d" begin
@testset "Linear upsampling (1D)" begin
x = Float64[1,2,3,4]
x = hcat(x,x,x)[:,:,:]

y = collect(1:1//3:4)
y = hcat(y,y,y)[:,:,:]
yF64 = Float64.(y)

@test y upsample_linear(x, 2.5)
@test y upsample_linear(x; size=10)
gradtest(x->upsample_linear(x, 2.5), x)
end

@testset "Bilinear upsampling (2D)" begin
x = Float32[1 2; 3 4][:,:,:,:]
x = cat(x,x; dims=3)
x = cat(x,x; dims=4)
Expand Down Expand Up @@ -65,7 +78,7 @@ end
@test y == y_true_int
end

@testset "Trilinear upsampling" begin
@testset "Trilinear upsampling (3D)" begin
# Layout: WHDCN, where D is depth
# we generate data which is constant along W & H and differs in D
# then we upsample along all dimensions
Expand Down

0 comments on commit 4411c86

Please sign in to comment.