Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broadcasting and cat - round 2 #137

Merged
merged 4 commits into from
Jul 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ReverseDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ include("tape.jl")
include("tracked.jl")
include("macros.jl")
include("derivatives/arrays.jl")
include("derivatives/broadcast.jl")
include("derivatives/propagation.jl")
include("derivatives/broadcast.jl")
include("derivatives/scalars.jl")
include("derivatives/elementwise.jl")
include("derivatives/linalg/arithmetic.jl")
Expand Down
29 changes: 19 additions & 10 deletions src/derivatives/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ end
function back(Δ)
start = 0
Δs = map(xs) do xsi
x = map(_ -> :, size(xsi))
i = isempty(x) ? x : Base.tail(x)
d = Δ[start+1:start+size(xsi,1), i...]
if xsi isa Number
d = Δ[start+1]
else
d = Δ[start+1:start+size(xsi,1), :]
end
start += size(xsi, 1)
d
end
Expand All @@ -75,11 +77,13 @@ end

@grad function hcat(xs::Union{Number, AbstractVecOrMat}...)
xs_value = value.(xs)
out_value = reduce(hcat,xs_value)
out_value = reduce(hcat, xs_value)
function back(Δ)
start = 0
Δs = map(xs) do xsi
d = if ndims(xsi) == 1
d = if ndims(xsi) == 0
Δ[start+1]
elseif ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Expand All @@ -102,11 +106,16 @@ end
return cat(Xs_value...; dims = dims), Δ -> begin
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = map(Xs) do xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
if xs isa Number
d = Δ[start+1]
start = start .+ 1
else
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
end
d
end
return (Δs...,)
Expand Down
20 changes: 10 additions & 10 deletions src/derivatives/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,29 +196,29 @@ end
results, _, bounds = instruction.cache
N = length(input)
if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input)))
_add_to_deriv!(input, output_deriv, results)
_br_add_to_deriv!(input, output_deriv, results)
else
_add_to_deriv!(input, output_deriv, results, bounds)
_br_add_to_deriv!(input, output_deriv, results, bounds)
end
unseed!(output)
return nothing
end

@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple}
@generated function _br_add_to_deriv!(xs::T, o, r) where {T <: Tuple}
N = length(T.types)
return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
end
_add_to_deriv!(_, _, _, _) = nothing
function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
_br_add_to_deriv!(_, _, _, _) = nothing
function _br_add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i)
end

@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
@generated function _br_add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
N = length(T.types)
return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
end
_add_to_deriv!(_, _, _, _, _) = nothing
function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
_br_add_to_deriv!(_, _, _, _, _) = nothing
function _br_add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i, bound)
end

Expand Down
24 changes: 24 additions & 0 deletions test/derivatives/ArrayFunctionTests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using ForwardDiff
using ReverseDiff: track, value, gradient, TrackedVector, TrackedMatrix, TrackedArray
using Test

Expand Down Expand Up @@ -32,6 +33,29 @@ function testcat(f, args::Tuple{Any, Any}, type, kwargs=NamedTuple())
x = f(track.(args)...; kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)

sizes = size.(args)
F = vecx -> sum(f(unpack(sizes, vecx)...; kwargs...))
X = pack(args)
@test ForwardDiff.gradient(F, X) == gradient(F, X)
end
function pack(xs)
return mapreduce(vcat, xs) do x
x isa Number ? x : vec(x)
end
end
function unpack(sizes, vecx)
start = 0
out = map(sizes) do s
if s === ()
x = vecx[start+1]
start += 1
else
x = reshape(vecx[start+1:start+prod(s)], s)
start += prod(s)
end
end
return out
end

@testset "cat" begin
Expand Down