improve bilinear upsampling #266

Merged
merged 23 commits into from Feb 5, 2021
Changes from all commits
321 changes: 126 additions & 195 deletions src/upsample.jl
@@ -64,235 +64,166 @@ function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::T
return Ω, upsample_nearest_pullback
end

"""
upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})

Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
using bilinear interpolation.

The size of the output is equal to
`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

The interpolation grid is identical to the one used by `imresize` from `Images.jl`.

Currently only 2d upsampling is supported.
"""
function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
# This function is gpu friendly

imgsize = size(x)
newsize = get_newsize(imgsize, k)

# Get linear interpolation lower- and upper index, and weights
ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1)
ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2)

# Adjust the upper interpolation indices of the second dimension
ihigh2_r = adjoint_of_idx(ilow2)[ihigh2]

@inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1
@inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2
# @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above
return y
end

function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray
# Creates interpolation grid for resampling.
# Creates the same grid as used in Image.jl `imresize`.
step = n // m
offset = (n + 1)//2 - step//2 - step * (m//2 - 1)
xq = clamp.(range(offset, step=step, length=m), 1, n)

# Creates interpolation lower and upper indices, and broadcastable weights
ilow = floor.(Int, xq)
ihigh = ceil.(Int, xq)
sizew = ntuple(i-> i == dim ? length(xq) : 1, ndims(x))
wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu
return ilow, ihigh, wdiff
# utility function
@inline function compute_source_index_and_lambda(
ratio, # 0 < ratio < 1
output_index,
input_size,
output_size
)
real_input_index = ratio*output_index
input_index0 = floor(Int, real_input_index) # typecast to int was here in C++
offset = (input_index0 < input_size - 1) ? 1 : 0
input_index1 = input_index0 + offset
lambda1 = real_input_index - input_index0
lambda0 = 1 - lambda1
return input_index0, input_index1, lambda0, lambda1
end
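
For intuition (not part of the diff), here is a minimal sketch of what this helper returns when upsampling a length-4 axis to length 8 with corners aligned; the numbers are illustrative only.

```julia
# Illustrative sketch of compute_source_index_and_lambda for in_h = 4, out_h = 8.
ratio = (4 - 1) / (8 - 1)            # corner-aligned scale, ≈ 0.4286
oh = 3                               # 0-based output index
real_idx = ratio * oh                # ≈ 1.286
ih0 = floor(Int, real_idx)           # 1   (lower source index, 0-based)
ih1 = ih0 + (ih0 < 4 - 1 ? 1 : 0)    # 2   (upper source index, clamped at the edge)
λ1 = real_idx - ih0                  # ≈ 0.286, weight of the upper neighbour
λ0 = 1 - λ1                          # ≈ 0.714, weight of the lower neighbour
```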

"""
adjoint_of_idx(idx::Vector{<:Integer})

# Arguments
- `idx`: a vector of indices from which you want the adjoint.
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})

# Outputs
-`idx_adjoint`: index that inverts the operation `x[idx]`.
Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`,
using bilinear interpolation. As an alternative to using `scale`, the resulting image `size`
can be directly specified with a keyword argument.

# Explanation
Determines the adjoint of the vector of indices `idx`, based on the following assumptions:
* `idx[1] == 1`
* `all(d in [0,1] for d in diff(idx))`
The adjoint of `idx` can be seen as an inverse operation such that:
The size of the output is equal to
`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

Examples:
```julia
x = [1, 2, 3, 4, 5]
idx = [1, 2, 2, 3, 4, 4, 5]
idx_adjoint = adjoint_of_idx(idx)
@assert x[idx][idx_adjoint] == x
upsample_bilinear(x, (2, pi)) # real scaling factors are allowed
upsample_bilinear(x; size=(64,64)) # specify output size
```
The above holds as long as `idx` contains every index in `x`.
"""
function adjoint_of_idx(idx::Vector{Int})
d = trues(length(idx))
d[2:end] .= diff(idx)
idx_adjoint = findall(d)
return idx_adjoint
end

function get_newsize(sz, k)
return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz))
function upsample_bilinear(x::AbstractArray{<:Any,4}, scale::NTuple{2,Real})
outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 2)
return upsample_bilinear(x; size=outsize)
end

upsample_bilinear(x, scale::Real) = upsample_bilinear(x, (scale,scale))
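
A hedged usage sketch (assuming the API added in this PR) showing how the scale methods resolve to an output size via `floor`:

```julia
# Sketch only: scalar and tuple scales resolve to a size with floor(Int, scale * dim).
x = rand(Float32, 10, 10, 3, 1)            # WHCN layout
y1 = upsample_bilinear(x, 2.5)             # same as upsample_bilinear(x, (2.5, 2.5))
size(y1)                                   # (25, 25, 3, 1)
y2 = upsample_bilinear(x; size=(25, 25))   # equivalent explicit output size
size(y2)                                   # (25, 25, 3, 1)
```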

"""
∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})

# Arguments
- `Δ`: array that has been upsampled using the upsample factors in `k`

# Outputs
- `dx`: downsampled version of `Δ`

# Explanation

Custom adjoint for [`upsample_bilinear`](@ref).
The adjoint of upsampling is a downsampling operation, which
in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the
upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects,
which have been corrected for manually.
"""
function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
# This function is gpu friendly

# Be more efficient on some corner cases
if size(Δ, 1) == k[1]
Δ = sum(Δ, dims=1)
k = (1, k[2])
end
if size(Δ, 2) == k[2]
Δ = sum(Δ, dims=2)
k = (k[1], 1)
function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
w,h,c,n = Base.size(x)
if (w,h) == size
return x
end
if (size(Δ, 1) == 1) && (size(Δ, 2) == 1)
dx = Δ
return dx
end

n_chan, n_batch = size(Δ, 3), size(Δ, 4)

kern1 = get_downsamplekernel(Δ, k[1])
kern2 = get_downsamplekernel(Δ, k[2])
kern = kern1 * kern2'

pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
stride = k

weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
weight .= 0
for i in 1:n_chan
weight[:,:,i,i] .= kern
end
# weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
dx = conv(Δ, weight, pad=pad, stride=stride)
y = similar(x, T, size..., c, n)
return upsample_bilinear_whcn!(y, x)
end

# Still have to fix edge effects due to zero-padding of convolution,
# TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index
# nextras = tuple((Int.(floor(factor//2)) for factor in k)...)
nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2))
function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T<:Integer
y = float.(x)
res = upsample_bilinear(y; size=size)
return round.(T, res)
end
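
A hedged sketch of the integer path above: the array is promoted to floating point, interpolated, and rounded back to the original integer eltype.

```julia
# Sketch only: integer eltypes round-trip through float.(x) and round.(T, ...).
xi = reshape(Int[1 3; 2 4], 2, 2, 1, 1)
yi = upsample_bilinear(xi; size=(3, 3))
eltype(yi)                                 # Int
size(yi)                                   # (3, 3, 1, 1)
```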

# First dimension edge-effect correction
if nextras[1] > 0
kern1 = kern[1:nextras[1],:]
pad1 = (0, pad[2])
stride1 = (1, stride[2])
weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan))
weight1 .= 0
for i in 1:n_chan
weight1[:,:,i,i] .= kern1
# this is the core function which works on arrays of arbitrary size
# the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
# which implements open-cv style linear interpolation / upsampling
# for simplicity, corners are aligned and all logic for other behaviour has been stripped
# - whcn because there is also a cwhn implementation
# - the function is parallelized using @threads
# - RGB types could be supported via reinterpreting
# - integer types need to be converted to Float and back
# - rationals work, but are slow
function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, in_h, channels, batches = size(input)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, _, _ = size(output)
output_slice_size = out_h * out_w

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))

@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for oh in 0:out_h-1
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + oh * out_w + ow + 1
output[output_offset] =
(h0lambda * w0lambda * input[idx(c, ih0, iw0)] + # h0 * w0 * i00
h0lambda * w1lambda * input[idx(c, ih0, iw1)] + # h0 * w1 * i01
h1lambda * w0lambda * input[idx(c, ih1, iw0)] + # h1 * w0 * i10
h1lambda * w1lambda * input[idx(c, ih1, iw1)]) # h1 * w1 * i11
end
end
# weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1)
weight1 .= weight1[end:-1:1,:,:,:]
dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1)

## Conv with views is not dispatched to CUDA.conv
# dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1)
# weight1 .= @view(weight1[end:-1:1,:,:,:])
# dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1)
end
return output
end
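
For readers unfamiliar with the indexing above, a hedged illustration (not part of the diff) of the 0-based linear offset that `idx(c, h, w)` computes into a contiguous WHCN array, with channel and batch folded together:

```julia
# Illustrative only: idx(c, h, w) is a 0-based offset shifted to Julia's 1-based indexing.
in_w, in_h = 5, 4
idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1       # same formula as in the kernel
A = reshape(collect(1.0:40.0), in_w, in_h, 2, 1)        # 5×4 image with 2 channels
A[idx(1, 2, 3)] == A[3 + 1, 2 + 1, 1 + 1, 1]            # true: (w, h, c) are zero-based
```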

# Second dimension edge-effect correction
if nextras[2] > 0
kern2 = kern[:,1:nextras[2]]
pad2 = (pad[1], 0)
stride2 = (stride[1], 1)
weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan))
weight2 .= 0
for i in 1:n_chan
weight2[:,:,i,i] .= kern2
end
# weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow

yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
weight2 .= weight2[:,end:-1:1,:,:]
dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2)
"""
∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T

## Conv with views is not dispatched to CUDA.conv
# yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
# dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
# weight2 .= @view(weight2[:,end:-1:1,:,:])
# dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2)
end
# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Lateral size (W, H) of the original image that was upsampled

## Finally fix four corners if needed
n1, n2 = nextras
if (n1 > 0) & (n2 > 0)
dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:]
dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
# Outputs
- `dx`: Downsampled version of `Δ`
"""
function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
w, h, c, n = Base.size(Δ)
out_w, out_h = size
if (w,h) == (out_w, out_h)
return Δ
end

return dx
dx = zero(similar(Δ, T, out_w, out_h, c, n))
return ∇upsample_bilinear_whcn!(dx, Δ)
end
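
A hedged shape check (assuming this PR's API): the adjoint takes a gradient at the upsampled resolution and returns one at the original resolution passed via `size`.

```julia
# Sketch only: the adjoint maps gradients back to the original spatial size.
Δ = rand(Float32, 25, 25, 3, 1)            # gradient w.r.t. the upsampled output
dx = ∇upsample_bilinear(Δ; size=(10, 10))  # (W, H) of the original input
size(dx)                                   # (10, 10, 3, 1)
```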

# `n` upsample factor for which a downsample kernel will be determined.
# Δ is given in case of necessity of gpu conversion
function get_downsamplekernel(Δ, n::Int)
step = 1//n
if n % 2 == 0
start = step//2
upward = collect(start:step:1//1)
kernel = [upward; reverse(upward)]
else
start = step
upward = collect(start:step:1//1)
kernel = [upward; reverse(upward[1:end-1])]
function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
in_w, in_h, channels, batches = size(dx)

# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, _, _ = size(Δ)
output_slice_size = out_h * out_w

width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))

@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for oh in 0:out_h-1
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + oh * out_w + ow + 1
Δ_value = Δ[output_offset]
dx[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00
dx[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01
dx[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10
dx[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11
end
end
end
# TODO there must be a more convenient way to send to gpu
kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
kernel = dropdims(kernel, dims=(2,3,4))
return kernel
return dx
end

function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
Ω = upsample_bilinear(x, k)
function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
Ω = upsample_bilinear(x; size=size)
function upsample_bilinear_pullback(Δ)
(NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist())
(NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
end
return Ω, upsample_bilinear_pullback
end
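
A hedged end-to-end sketch, assuming Zygote is loaded alongside NNlib, showing that differentiating through `upsample_bilinear` returns a gradient of the input's size via the rule above:

```julia
# Sketch only: the custom rrule is picked up by AD (assumes Zygote + NNlib are available).
using NNlib, Zygote
x = rand(Float32, 8, 8, 1, 1)
g = Zygote.gradient(x -> sum(upsample_bilinear(x; size=(16, 16))), x)[1]
size(g)                                    # (8, 8, 1, 1)
```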


"""
pixel_shuffle(x, r)

Pixel shuffling operation. `r` is the upscale factor for shuffling.
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N].
Used extensively in super-resolution networks to upsample
towards high resolution features.

Reference : https://arxiv.org/pdf/1609.05158.pdf
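
A hedged example of the size bookkeeping described above: with upscale factor `r`, `r²·C` channels trade for an `r`-times larger spatial extent.

```julia
# Sketch only: [W, H, r²C, N] -> [rW, rH, C, N] with r = 2.
x = rand(Float32, 7, 8, 4, 1)              # 4 = 2^2 * 1 channels
y = pixel_shuffle(x, 2)
size(y)                                    # (14, 16, 1, 1)
```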
@@ -301,7 +232,7 @@ function pixel_shuffle(x::AbstractArray, r::Integer)
@assert ndims(x) > 2
d = ndims(x) - 2
sizein = size(x)[1:d]
cin, n = size(x, d+1), size(x, d+2)
@assert cin % r^d == 0
cout = cin ÷ r^d
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866