From 76a3504ab3f29a3b0a36eb2936c06db243d9e4d2 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 18 Jul 2020 22:12:18 +1000 Subject: [PATCH 1/4] include broadcast.jl after propagation.jl --- src/ReverseDiff.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ReverseDiff.jl b/src/ReverseDiff.jl index 779038b..1b339c4 100644 --- a/src/ReverseDiff.jl +++ b/src/ReverseDiff.jl @@ -29,8 +29,8 @@ include("tape.jl") include("tracked.jl") include("macros.jl") include("derivatives/arrays.jl") -include("derivatives/broadcast.jl") include("derivatives/propagation.jl") +include("derivatives/broadcast.jl") include("derivatives/scalars.jl") include("derivatives/elementwise.jl") include("derivatives/linalg/arithmetic.jl") From c7f16164d2520d32eebcb16635cf3a3ff840f9b0 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 18 Jul 2020 22:35:18 +1000 Subject: [PATCH 2/4] fix name clash --- src/derivatives/broadcast.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/derivatives/broadcast.jl b/src/derivatives/broadcast.jl index 93f0023..dbacdf7 100644 --- a/src/derivatives/broadcast.jl +++ b/src/derivatives/broadcast.jl @@ -196,29 +196,29 @@ end results, _, bounds = instruction.cache N = length(input) if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input))) - _add_to_deriv!(input, output_deriv, results) + _br_add_to_deriv!(input, output_deriv, results) else - _add_to_deriv!(input, output_deriv, results, bounds) + _br_add_to_deriv!(input, output_deriv, results, bounds) end unseed!(output) return nothing end -@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple} +@generated function _br_add_to_deriv!(xs::T, o, r) where {T <: Tuple} N = length(T.types) - return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...) + return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...) end -_add_to_deriv!(_, _, _, _) = nothing -function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i} +_br_add_to_deriv!(_, _, _, _) = nothing +function _br_add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i} return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i) end -@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple} +@generated function _br_add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple} N = length(T.types) - return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...) + return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...) end -_add_to_deriv!(_, _, _, _, _) = nothing -function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i} +_br_add_to_deriv!(_, _, _, _, _) = nothing +function _br_add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i} return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i, bound) end @@ -283,7 +283,7 @@ end N = length(input) Δargs = _derivs(f, output_deriv, value.(input)...) dxs = map(unbroadcast, input, Δargs) - map(_add_to_deriv!, input, dxs) + map(_br_add_to_deriv!, input, dxs) unseed!(output) return nothing end From e813abcc6aed96248deb7f7bf941c65cc070fc6d Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 18 Jul 2020 23:26:29 +1000 Subject: [PATCH 3/4] fix cat and test the gradient --- src/derivatives/arrays.jl | 29 +++++++++++++++++--------- test/derivatives/ArrayFunctionTests.jl | 24 +++++++++++++++++++++ 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/derivatives/arrays.jl b/src/derivatives/arrays.jl index 44ded91..5ff0213 100644 --- a/src/derivatives/arrays.jl +++ b/src/derivatives/arrays.jl @@ -62,9 +62,11 @@ end function back(Δ) start = 0 Δs = map(xs) do xsi - x = map(_ -> :, size(xsi)) - i = isempty(x) ? x : Base.tail(x) - d = Δ[start+1:start+size(xsi,1), i...] + if xsi isa Number + d = Δ[start+1] + else + d = Δ[start+1:start+size(xsi,1), :] + end start += size(xsi, 1) d end @@ -75,11 +77,13 @@ end @grad function hcat(xs::Union{Number, AbstractVecOrMat}...) xs_value = value.(xs) - out_value = reduce(hcat,xs_value) + out_value = reduce(hcat, xs_value) function back(Δ) start = 0 Δs = map(xs) do xsi - d = if ndims(xsi) == 1 + d = if ndims(xsi) == 0 + Δ[start+1] + elseif ndims(xsi) == 1 Δ[:, start+1] else i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail @@ -102,11 +106,16 @@ end return cat(Xs_value...; dims = dims), Δ -> begin start = ntuple(i -> 0, Val(ndims(Δ))) Δs = map(Xs) do xs - dim_xs = 1:ndims(xs) - till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ))) - xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ))) - d = reshape(Δ[xs_in_Δ...],size(xs)) - start = start .+ till_xs + if xs isa Number + d = Δ[start+1] + start = start .+ 1 + else + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ))) + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ))) + d = reshape(Δ[xs_in_Δ...],size(xs)) + start = start .+ till_xs + end d end return (Δs...,) diff --git a/test/derivatives/ArrayFunctionTests.jl b/test/derivatives/ArrayFunctionTests.jl index a0e8dfe..236cf1e 100644 --- a/test/derivatives/ArrayFunctionTests.jl +++ b/test/derivatives/ArrayFunctionTests.jl @@ -1,3 +1,4 @@ +using ForwardDiff using ReverseDiff: track, value, gradient, TrackedVector, TrackedMatrix, TrackedArray using Test @@ -32,6 +33,29 @@ function testcat(f, args::Tuple{Any, Any}, type, kwargs=NamedTuple()) x = f(track.(args)...; kwargs...) @test x isa type @test value(x) == f(args...; kwargs...) + + sizes = size.(args) + F = vecx -> sum(f(unpack(sizes, vecx)...; kwargs...)) + X = pack(args) + @test ForwardDiff.gradient(F, X) == gradient(F, X) +end +function pack(xs) + return mapreduce(vcat, xs) do x + x isa Number ? x : vec(x) + end +end +function unpack(sizes, vecx) + start = 0 + out = map(sizes) do s + if s === () + x = vecx[start+1] + start += 1 + else + x = reshape(vecx[start+1:start+prod(s)], s) + start += prod(s) + end + end + return out end @testset "cat" begin From 343a75401c1634b0c94cf4e660ff510b4e9a72e3 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 18 Jul 2020 23:26:46 +1000 Subject: [PATCH 4/4] typo --- src/derivatives/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/derivatives/broadcast.jl b/src/derivatives/broadcast.jl index dbacdf7..3cd2b56 100644 --- a/src/derivatives/broadcast.jl +++ b/src/derivatives/broadcast.jl @@ -283,7 +283,7 @@ end N = length(input) Δargs = _derivs(f, output_deriv, value.(input)...) dxs = map(unbroadcast, input, Δargs) - map(_br_add_to_deriv!, input, dxs) + map(_add_to_deriv!, input, dxs) unseed!(output) return nothing end