| 
 | 1 | +using .ReverseDiff: compile, GradientTape  | 
 | 2 | +using .ReverseDiff.DiffResults: GradientResult  | 
 | 3 | + | 
# AD backend using ReverseDiff.jl. The `cache` type parameter is a `Bool`
# selecting whether compiled tapes are memoized (the `true` methods are only
# defined once Memoization.jl is loaded; see the `@require` block below).
struct ReverseDiffAD{cache} <: ADBackend end
# Global flag: whether compiled ReverseDiff tapes should be cached.
# Read via `getrdcache()` and toggled via `setrdcache`.
const RDCache = Ref(false)
 | 6 | +setrdcache(b::Bool) = setrdcache(Val(b))  | 
 | 7 | +setrdcache(::Val{false}) = RDCache[] = false  | 
 | 8 | +setrdcache(::Val) = throw("Memoization.jl is not loaded. Please load it before setting the cache to true.")  | 
# Stub; the actual method is added only when Memoization.jl is loaded (see the
# `@require` block below), where it clears all memoized tapes.
function emptyrdcache end
 | 10 | + | 
 | 11 | +getrdcache() = RDCache[]  | 
 | 12 | +ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}  | 
 | 13 | +function setadbackend(::Val{:reversediff})  | 
 | 14 | +    ADBACKEND[] = :reversediff  | 
 | 15 | +end  | 
 | 16 | + | 
 | 17 | +function gradient_logp(  | 
 | 18 | +    backend::ReverseDiffAD{false},  | 
 | 19 | +    θ::AbstractVector{<:Real},  | 
 | 20 | +    vi::VarInfo,  | 
 | 21 | +    model::Model,  | 
 | 22 | +    sampler::AbstractSampler = SampleFromPrior(),  | 
 | 23 | +)  | 
 | 24 | +    T = typeof(getlogp(vi))  | 
 | 25 | +      | 
 | 26 | +    # Specify objective function.  | 
 | 27 | +    function f(θ)  | 
 | 28 | +        new_vi = VarInfo(vi, sampler, θ)  | 
 | 29 | +        model(new_vi, sampler)  | 
 | 30 | +        return getlogp(new_vi)  | 
 | 31 | +    end  | 
 | 32 | +    tp, result = taperesult(f, θ)  | 
 | 33 | +    ReverseDiff.gradient!(result, tp, θ)  | 
 | 34 | +    l = DiffResults.value(result)  | 
 | 35 | +    ∂l∂θ::typeof(θ) = DiffResults.gradient(result)  | 
 | 36 | + | 
 | 37 | +    return l, ∂l∂θ  | 
 | 38 | +end  | 
 | 39 | + | 
 | 40 | +tape(f, x) = GradientTape(f, x)  | 
 | 41 | +function taperesult(f, x)  | 
 | 42 | +    return tape(f, x), GradientResult(x)  | 
 | 43 | +end  | 
 | 44 | + | 
# Cached-tape support requires Memoization.jl; these definitions are only
# activated once the user loads that package.
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
    # Enabling the cache is only legal when Memoization.jl is available.
    setrdcache(::Val{true}) = RDCache[] = true
    # Drop every memoized `(tape, result)` entry belonging to
    # `memoized_taperesult`, so stale tapes are not reused (e.g. after a
    # model change).
    function emptyrdcache()
        for k in keys(Memoization.caches)
            if k[1] === typeof(memoized_taperesult)
                pop!(Memoization.caches, k)
            end
        end
    end
    # Cached variant of `gradient_logp`: reuses a compiled tape memoized per
    # model/sampler and input shape (see `RDTapeKey` below).
    function gradient_logp(
        backend::ReverseDiffAD{true},
        θ::AbstractVector{<:Real},
        vi::VarInfo,
        model::Model,
        sampler::AbstractSampler = SampleFromPrior(),
    )
        # NOTE(review): `T` is computed but never used below — candidate for removal.
        T = typeof(getlogp(vi))
        
        # Specify objective function.
        function f(θ)
            new_vi = VarInfo(vi, sampler, θ)
            model(new_vi, sampler)
            return getlogp(new_vi)
        end
        # Fetch (or build and memoize) a compiled tape plus result buffer,
        # then run the reverse pass.
        ctp, result = memoized_taperesult(f, θ)
        ReverseDiff.gradient!(result, ctp, θ)
        l = DiffResults.value(result)
        ∂l∂θ = DiffResults.gradient(result)

        return l, ∂l∂θ
    end

    # This makes sure we generate a single tape per Turing model and sampler
    struct RDTapeKey{F, Tx}
        f::F
        x::Tx
    end
    # NOTE(review): this overloads an internal function of Memoization.jl
    # (type piracy on `Memoization._get!`) so that the memoization key becomes
    # `(typeof(f), typeof(x), size(x))` instead of the `RDTapeKey` instance
    # itself; it may break if Memoization's internals change.
    function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
        key = keys[1][1]
        return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
    end
    memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
    # Memoized: one compiled tape + result buffer per (function type, input
    # type, input size), via the `_get!` key rewrite above.
    Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
        return compiledtape(k.f, k.x), GradientResult(k.x)
    end
    memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
    Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
    # Record and compile a ReverseDiff gradient tape for `f` at `x`.
    compiledtape(f, x) = compile(GradientTape(f, x))
end
0 commit comments