diff --git a/src/upsample.jl b/src/upsample.jl
index 20b343819..f2623c79a 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -1,4 +1,5 @@
 export upsample_nearest, ∇upsample_nearest,
+    upsample_linear, ∇upsample_linear,
     upsample_bilinear, ∇upsample_bilinear,
     upsample_trilinear, ∇upsample_trilinear,
     pixel_shuffle
@@ -96,6 +97,120 @@ end
     return input_index0, input_index1, lambda0, lambda1
 end
 
+###########
+# linear
+###########
+"""
+    upsample_linear(x::AbstractArray{T,3}, scale::Real)
+    upsample_linear(x::AbstractArray{T,3}; size::Integer)
+
+Upsamples the first dimension of the array `x` by the provided `scale`,
+using linear interpolation. As an alternative to using `scale`, the resulting
+array `size` can be directly specified with a keyword argument.
+
+The size of the output is equal to
+`(scale*S1, S2, S3)`, where `S1, S2, S3 = size(x)`.
+"""
+function upsample_linear(x::AbstractArray{<:Any,3}, scale::Real)
+    outsize = floor(Int, scale * Base.size(x)[1])
+    return upsample_linear(x; size=outsize)
+end
+
+function upsample_linear(x::AbstractArray{T,3}; size::Integer) where T
+    w, c, n = Base.size(x)
+    if w == size
+        return x
+    end
+    y = similar(x, T, size, c, n)
+    return upsample_linear_wcn!(y, x)
+end
+
+function upsample_linear(x::AbstractArray{T,3}; size::Integer) where T<:Integer
+    y = float.(x)
+    res = upsample_linear(y; size=size)
+    return round.(T, res)
+end
+
+function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
+    size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
+    in_w, channels, batches = size(input)
+    # treat batch and channel dimensions as one for better parallelization granularity
+    channels *= batches
+    out_w, _, _ = size(output)
+    output_slice_size = out_w
+
+    # T() and // so that we can handle rationals (super slow)
+    width_scale = T((in_w - 1) // (out_w - 1))
+
+    @inline idx(c, w) = c * in_w + w + 1
+
+    @inbounds Threads.@threads for c in 0:channels-1
+        for ow in 0:out_w-1
+            iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
+            output_offset = c * output_slice_size + ow + 1
+            output[output_offset] = (w0lambda * input[idx(c, iw0)] + # w0 * i0
+                                     w1lambda * input[idx(c, iw1)])  # w1 * i1
+        end
+    end
+    return output
+end
+
+"""
+    ∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer) where T
+
+# Arguments
+- `Δ`: Incoming gradient array, backpropagated from downstream layers
+- `size`: Size of the original array (before upsampling) along the first dimension
+
+# Outputs
+- `dx`: Downsampled version of `Δ`
+"""
+function ∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer) where T
+    w, c, n = Base.size(Δ)
+    out_w = size
+    if w == out_w
+        return Δ
+    end
+    dx = zero(similar(Δ, T, out_w, c, n))
+    return ∇upsample_linear_wcn!(dx, Δ)
+end
+
+function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
+    size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
+    in_w, channels, batches = size(dx)
+
+    # treat batch and channel dimensions as one for better parallelization granularity
+    channels *= batches
+    out_w, _, _ = size(Δ)
+    output_slice_size = out_w
+
+    width_scale = T((in_w - 1) // (out_w - 1))
+
+    @inline idx(c, w) = c * in_w + w + 1
+
+    @inbounds Threads.@threads for c in 0:channels-1
+        for ow in 0:out_w-1
+            iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
+            output_offset = c * output_slice_size + ow + 1
+            Δ_value = Δ[output_offset]
+            dx[idx(c, iw0)] += w0lambda * Δ_value # i0
+            dx[idx(c, iw1)] += w1lambda * Δ_value # i1
+        end
+    end
+    return dx
+end
+
+function rrule(::typeof(upsample_linear), x; size)
+    Ω = upsample_linear(x; size=size)
+    function upsample_linear_pullback(Δ)
+        (NO_FIELDS, ∇upsample_linear(Δ; size=Base.size(x, 1)))
+    end
+    return Ω, upsample_linear_pullback
+end
+
+###########
+# bilinear
+###########
 """
     upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
     upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})
diff --git a/test/upsample.jl b/test/upsample.jl
index 7440bdc39..24f71d0b1 100644
--- a/test/upsample.jl
+++ b/test/upsample.jl
@@ -17,7 +17,20 @@
     @test_throws ArgumentError upsample_nearest(x, size=(3,4))
 end
 
-@testset "upsample_bilinear 2d" begin
+@testset "Linear upsampling (1D)" begin
+    x = Float64[1,2,3,4]
+    x = hcat(x,x,x)[:,:,:]
+
+    y = collect(1:1//3:4)
+    y = hcat(y,y,y)[:,:,:]
+    yF64 = Float64.(y)
+
+    @test y ≈ upsample_linear(x, 2.5)
+    @test y ≈ upsample_linear(x; size=10)
+    gradtest(x->upsample_linear(x, 2.5), x)
+end
+
+@testset "Bilinear upsampling (2D)" begin
     x = Float32[1 2; 3 4][:,:,:,:]
     x = cat(x,x; dims=3)
     x = cat(x,x; dims=4)
@@ -65,7 +78,7 @@
     @test y == y_true_int
 end
 
-@testset "Trilinear upsampling" begin
+@testset "Trilinear upsampling (3D)" begin
     # Layout: WHDCN, where D is depth
     # we generate data which is constant along W & H and differs in D
     # then we upsample along all dimensions
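
Usage sketch (not part of the patch): a minimal illustration of the 1D API added above, assuming this is NNlib.jl, which exports both `upsample_linear` and `∇upsample_linear` per the export list in the diff.

    using NNlib

    x = reshape(Float32[1, 2, 3, 4], 4, 1, 1)   # layout (width, channels, batch)
    y = upsample_linear(x, 2.5)                 # output width = floor(2.5 * 4) = 10
    y ≈ upsample_linear(x; size = 10)           # keyword form gives the same result;
                                                # values step by (4 - 1)/(10 - 1) = 1/3,
                                                # matching the 1:1//3:4 expectation in the tests

    # Backward pass: `size` is the width of the original (pre-upsampling) array.
    Δ  = ones(Float32, 10, 1, 1)
    dx = ∇upsample_linear(Δ; size = 4)          # dx has size (4, 1, 1)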