Skip to content

Fix StackOverflow errors for long recurrences #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ a::Call == b::Call = a.func == b.func && a.args == b.args
@inline (c::Call)() = c.func(data.(c.args)...)

mutable struct Tracked{T}
ref::UInt32
f::Call
isleaf::Bool
grad::T
Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
Tracked{T}(f::Call) where T = new(f, false)
Tracked{T}(f::Call, grad::T) where T = new(f, false, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(f, true, grad)
end

istracked(x::Tracked) = true
Expand Down
101 changes: 58 additions & 43 deletions src/back.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,51 +14,65 @@ init_grad(x) = zero(x)
# Reset a gradient to zero.
#
# The generic method returns a fresh `zero(x)` and leaves `x` itself untouched, so
# callers have to store the returned value. The array method overwrites the
# contents of `x` in place and returns `x`.
function zero_grad!(x)
    return zero(x)
end

function zero_grad!(x::AbstractArray)
    x .= 0
    return x
end

scan(c::Call) = foreach(scan, c.args)

function scan(x::Tracked)
x.isleaf && return
ref = x.ref += 1
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
# Enqueue every argument of `c` that has not been visited yet.
#
# `nothing` arguments are ignored. Each remaining argument is identified by its
# `objectid`; arguments whose id is not yet in `seen` are recorded there and
# inserted at the front of `queue`.
function _walk(queue, seen, c::Call)
    for arg in c.args
        arg === nothing && continue
        key = objectid(arg)
        key in seen && continue
        push!(seen, key)
        pushfirst!(queue, arg)
    end
    return
end

function scan(x)
istracked(x) && scan(tracker(x))
return
# Iteratively traverse the graph reachable from `x`, calling `f(node, seen)` on
# each node taken from the work queue.
#
# Nodes are popped from the back of the queue while `_walk` inserts a node's
# unseen arguments at the front. When `once` is true, every non-leaf node has its
# call replaced by `Call(missing, ())` after it has been processed.
function walk(f, x::Tracked, seen = Set{UInt64}(); once = true)
    pending = Tracked[x]
    while !isempty(pending)
        node = pop!(pending)
        f(node, seen)
        _walk(pending, seen, node.f)
        if once && !node.isleaf
            node.f = Call(missing, ())
        end
    end
    return nothing
end

# Pull the gradient `Δ` back through the call `c` and hand the pieces to its
# arguments.
#
# `c.func(Δ)` must return a tuple with at least `length(c.args)` entries, otherwise
# an error is raised. Every argument that already carries a `grad` and whose
# `objectid` is not in `seen` has that gradient reset before the new contributions
# are accumulated with the two-argument `back_`.
function back_(c::Call, Δ, seen)
    Δs = c.func(Δ)
    (Δs isa Tuple && length(Δs) >= length(c.args)) ||
        error("Gradient is not a tuple of length $(length(c.args))")
    foreach(c.args) do x
        isdefined(x, :grad) || return
        # `zero_grad!` only zeroes arrays in place; for any other gradient it returns
        # a fresh zero, so the result has to be written back to the field.
        objectid(x) ∉ seen && (x.grad = zero_grad!(x.grad))
        return
    end
    foreach((x, d) -> back_(x, d), c.args, data.(Δs))
end

back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
# A `Call{Nothing}` has no function to pull back through, so there is nothing to do.
function back_(::Call{Nothing}, Δ, seen)
    return nothing
end

# A `Call{Missing}` is the placeholder `walk` stores in a node once it has been
# processed with `once = true`; reaching it again means the graph was already consumed.
function back_(::Call{Missing}, Δ, seen)
    return error("`back!` was already used")
end

# Add the contribution `Δ` to the gradient `x`.
#
# The generic method returns the broadcast sum without modifying `x`, so callers
# have to store the result. The array method accumulates in place and returns `x`.
function accum!(x, Δ)
    return x .+ Δ
end

function accum!(x::AbstractArray, Δ)
    x .+= Δ
    return x
end

# Record the contribution `Δ` on the tracker `x`: accumulate into an existing
# gradient with `accum!`, or use `Δ` as the first value when no gradient has been
# assigned yet.
function back_(x::Tracked, Δ)
    x.grad = isdefined(x, :grad) ? accum!(x.grad, Δ) : Δ
    return
end

# Arguments without a tracker (`nothing`) receive no gradient.
function back_(::Nothing, Δ)
    return nothing
end

function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
grad = if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end
return
seen = Set{UInt64}(objectid(x))
if isdefined(x, :grad)
x.grad = zero_grad!(x.grad)
end
back_(x, Δ)
walk(x, seen, once = once) do x, seen
back_(x.f, x.grad, seen)
end
end

back(::Nothing, Δ, once) = return
Expand All @@ -73,7 +87,6 @@ back(::Nothing, Δ, once) = return

function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ, once)
return
end
Expand Down Expand Up @@ -110,23 +123,26 @@ function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
foreach((x, Δ) -> back_(g, x, Δ), c.args, Δs)
end

back_(g::Grads, ::Call{Nothing}, Δ) = nothing

function back(g::Grads, x::Tracked, Δ)
function back_(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)
ref = x.ref -= 1
if ref > 0 || haskey(g, x)
accum!(g, x, Δ)
ref == 0 && back_(g, x.f, g[x])
else
ref == 0 && back_(g, x.f, Δ)
end
accum!(g, x, Δ)
return
end

back_(g::Grads, ::Nothing, Δ) = return

# Backpropagate `Δ` from `x` into the gradient store `g`.
#
# The seed gradient is first recorded for `x` itself; the graph is then walked
# with `once = false`, and each visited node pushes its stored gradient `g[node]`
# back through its call.
function back(g::Grads, x::Tracked, Δ)
    back_(g, x, Δ)
    visit = (node, _seen) -> back_(g, node.f, g[node])
    walk(visit, x; once = false)
end

back(::Grads, ::Nothing, _) = return

collectmemaybe(xs) = xs
Expand All @@ -136,7 +152,6 @@ function forward(f, ps::Params)
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
back(g, tracker(y), Δ)
end
return g
Expand Down Expand Up @@ -168,7 +183,7 @@ gradient(f, xs...; nest = false) =

"""
J = jacobian(m,x)

Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(f, x::AbstractVector)
Expand Down
25 changes: 14 additions & 11 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import LinearAlgebra: inv, det, logdet, logabsdet, \, /
using Statistics
using LinearAlgebra: Diagonal, Transpose, Adjoint, diagm, diag

struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
struct TrackedArray{T,N,A<:AbstractArray{T,N},B} <: AbstractArray{T,N}
tracker::Tracked{B}
data::A
grad::A
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
grad::B
TrackedArray{T,N,A,B}(t::Tracked{B}, data::A) where {T,N,A,B} = new(t, data)
TrackedArray{T,N,A,B}(t::Tracked{B}, data::A, grad::B) where {T,N,A,B} = new(t, data, grad)
end

data(x::TrackedArray) = x.data
Expand All @@ -23,11 +23,14 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}

track(c::Call, x::AbstractArray) = TrackedArray(c, x)

TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A,A}(Tracked{A}(c), x)

TrackedArray(c::Call, x::A) where A <: Union{SubArray, Transpose, Adjoint, PermutedDimsArray} =
TrackedArray{eltype(A),ndims(A),A,Any}(Tracked{Any}(c), x)

TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray{eltype(A),ndims(A),A,A}(Tracked{A}(c, Δ), x, Δ)

TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))

Expand All @@ -38,12 +41,12 @@ Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x
Base.convert(T::Type{<:TrackedArray}, x::TrackedArray) =
error("Not implemented: convert $(typeof(x)) to $T")

Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} =
Base.convert(::Type{<:TrackedArray{T,N,A,B}}, x::AbstractArray) where {T,N,A,B} =
TrackedArray(convert(A, x))

Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
Base.show(io::IO, t::Type{TrackedArray{T,N,A,B}}) where {T,N,A<:AbstractArray{T,N},B} =
@isdefined(A) ?
print(io, "TrackedArray{…,$A}") :
print(io, "TrackedArray{…,$A,...}") :
invoke(show, Tuple{IO,DataType}, io, t)

function Base.summary(io::IO, x::TrackedArray)
Expand Down
26 changes: 19 additions & 7 deletions src/lib/real.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,28 @@ function collect(xs)
track(Call(collect, (tracker.(xs),)), data.(xs))
end

function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end

function back_(c::Call{typeof(collect)}, Δ, once)
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
# Backward step for `collect` calls, where `c.args[1]` holds the trackers of the
# collected elements and `Δ` supplies one gradient per element.
#
# Every element that already carries a `grad` and whose `objectid` is not in `seen`
# has that gradient reset before the new contributions are accumulated.
function back_(c::Call{typeof(collect)}, Δ, seen)
    foreach(c.args[1]) do x
        isdefined(x, :grad) || return
        # `zero_grad!` returns a fresh zero for non-array gradients instead of
        # mutating them, so the result has to be written back to the field.
        objectid(x) ∉ seen && (x.grad = zero_grad!(x.grad))
        return
    end
    foreach((x, d) -> back_(x, d), c.args[1], data(Δ))
end

function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
foreach((x, Δ) -> back_(g, x, Δ), c.args[1], Δ)
end

# `collect` variant of `_walk`: the trackers to visit live in the collection
# `c.args[1]` rather than directly in `c.args`.
#
# `nothing` entries are ignored; entries whose `objectid` is not yet in `seen` are
# recorded there and inserted at the front of `queue`.
function _walk(queue, seen, c::Call{typeof(collect)})
    for arg in c.args[1]
        arg === nothing && continue
        key = objectid(arg)
        key in seen && continue
        push!(seen, key)
        pushfirst!(queue, arg)
    end
    return nothing
end

collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
Expand Down
10 changes: 10 additions & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,16 @@ end
@test back([1, 1]) == (32,)
end

@testset "Long Recurrences" begin
    # A scalar accumulated over 10000 sequential steps builds a very deep graph;
    # the gradient of a plain sum is exactly one for every input element.
    n = 10000
    grads = Tracker.gradient(rand(n)) do x
        total = 0.0
        for i in 1:length(x)
            total += x[i]
        end
        return total
    end
    @test grads[1] == ones(n)
end

@testset "PDMats" begin
B = rand(5, 5)
S = PDMat(I + B * B')
Expand Down