Skip to content

Commit

Permalink
Merge pull request #137 from JuliaDiff/mt/healed_reversediff
Browse files Browse the repository at this point in the history
Fix broadcasting and cat - round 2
  • Loading branch information
mohamed82008 authored Jul 18, 2020
2 parents 60a92ac + 343a754 commit 62601fb
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/ReverseDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ include("tape.jl")
include("tracked.jl")
include("macros.jl")
include("derivatives/arrays.jl")
include("derivatives/broadcast.jl")
include("derivatives/propagation.jl")
include("derivatives/broadcast.jl")
include("derivatives/scalars.jl")
include("derivatives/elementwise.jl")
include("derivatives/linalg/arithmetic.jl")
Expand Down
29 changes: 19 additions & 10 deletions src/derivatives/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ end
function back(Δ)
start = 0
Δs = map(xs) do xsi
x = map(_ -> :, size(xsi))
i = isempty(x) ? x : Base.tail(x)
d = Δ[start+1:start+size(xsi,1), i...]
if xsi isa Number
d = Δ[start+1]
else
d = Δ[start+1:start+size(xsi,1), :]
end
start += size(xsi, 1)
d
end
Expand All @@ -75,11 +77,13 @@ end

@grad function hcat(xs::Union{Number, AbstractVecOrMat}...)
xs_value = value.(xs)
out_value = reduce(hcat,xs_value)
out_value = reduce(hcat, xs_value)
function back(Δ)
start = 0
Δs = map(xs) do xsi
d = if ndims(xsi) == 1
d = if ndims(xsi) == 0
Δ[start+1]
elseif ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Expand All @@ -102,11 +106,16 @@ end
return cat(Xs_value...; dims = dims), Δ -> begin
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = map(Xs) do xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
if xs isa Number
d = Δ[start+1]
start = start .+ 1
else
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
end
d
end
return (Δs...,)
Expand Down
20 changes: 10 additions & 10 deletions src/derivatives/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,29 +196,29 @@ end
results, _, bounds = instruction.cache
N = length(input)
if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input)))
_add_to_deriv!(input, output_deriv, results)
_br_add_to_deriv!(input, output_deriv, results)
else
_add_to_deriv!(input, output_deriv, results, bounds)
_br_add_to_deriv!(input, output_deriv, results, bounds)
end
unseed!(output)
return nothing
end

@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple}
@generated function _br_add_to_deriv!(xs::T, o, r) where {T <: Tuple}
N = length(T.types)
return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
end
_add_to_deriv!(_, _, _, _) = nothing
function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
_br_add_to_deriv!(_, _, _, _) = nothing
function _br_add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i)
end

@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
@generated function _br_add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
N = length(T.types)
return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
end
_add_to_deriv!(_, _, _, _, _) = nothing
function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
_br_add_to_deriv!(_, _, _, _, _) = nothing
function _br_add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i, bound)
end

Expand Down
24 changes: 24 additions & 0 deletions test/derivatives/ArrayFunctionTests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using ForwardDiff
using ReverseDiff: track, value, gradient, TrackedVector, TrackedMatrix, TrackedArray
using Test

Expand Down Expand Up @@ -32,6 +33,29 @@ function testcat(f, args::Tuple{Any, Any}, type, kwargs=NamedTuple())
x = f(track.(args)...; kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)

sizes = size.(args)
F = vecx -> sum(f(unpack(sizes, vecx)...; kwargs...))
X = pack(args)
@test ForwardDiff.gradient(F, X) == gradient(F, X)
end
# Flatten every argument in `xs` and concatenate the pieces into one vector.
# Numbers are passed through unchanged; arrays are flattened in column-major
# order with `vec` before being joined with `vcat`.
function pack(xs)
    pieces = map(xs) do item
        if item isa Number
            item
        else
            vec(item)
        end
    end
    return reduce(vcat, pieces)
end
# Split the flat vector `vecx` back into one value per entry of `sizes`
# (the inverse of `pack`).
#
# `sizes` holds the `size` of each original argument. An empty tuple `()`
# marks a scalar, which consumes a single element of `vecx` and is returned
# as a plain number; any other size consumes `prod(s)` elements and is
# returned reshaped to `s`. The result has one entry per element of `sizes`.
function unpack(sizes, vecx)
    start = 0
    out = map(sizes) do s
        if s === ()
            # Scalar argument: take exactly one element, do not wrap it in an array.
            x = vecx[start+1]
            start += 1
        else
            len = prod(s)
            x = reshape(vecx[start+1:start+len], s)
            start += len
        end
        # Yield the unpacked value. Without this trailing expression the block
        # evaluates to the updated offset (`start += ...`), so `map` would
        # collect integers instead of the unpacked data.
        x
    end
    return out
end

@testset "cat" begin
Expand Down

0 comments on commit 62601fb

Please sign in to comment.