
ReverseDiff + callable LogDensityAt struct doesn't like accumulators #946

Closed
@penelopeysm

Description


Accumulators: #885

LogDensityAt struct: #922

ReverseDiff doesn't like the combination of the two (check out the `breaking` branch, where both PRs have been merged):

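```julia
using DynamicPPL, Distributions
using DynamicPPL.TestUtils.AD: run_ad
using ADTypes: AutoReverseDiff
import ReverseDiff

# Minimal model with a single scalar parameter
@model f() = x ~ Normal()

# Run DynamicPPL's AD testing utility with the ReverseDiff backend
run_ad(f(), AutoReverseDiff())
```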

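This errors with the following (the trace below was captured on `DynamicPPL.TestUtils.demo_assume_matrix_observe_matrix_index` rather than on `f`):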

```
MethodError: no method matching seeded_reverse_pass!(::DiffResults.MutableDiffResult{1, Float64, Tuple{Vector{Float64}}}, ::@NamedTuple{logprior::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, loglikelihood::ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, ::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ::ReverseDiff.GradientTape{DynamicPPL.LogDensityAt{Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_observe_matrix_index), (:x, Symbol("##arg#500")), (), (), Tuple{Transpose{Float64, Matrix{Float64}}, DynamicPPL.TypeWrap{Array{Float64}}}, Tuple{}, DefaultContext}, SimpleVarInfo{@NamedTuple{s::Matrix{Float64}, m::Vector{Float64}}, DynamicPPL.AccumulatorTuple{3, @NamedTuple{LogPrior::LogPriorAccumulator{Float64}, LogLikelihood::LogLikelihoodAccumulator{Float64}, NumProduce::NumProduceAccumulator{Int64}}}, DynamicPPL.DynamicTransformation}, DefaultContext}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, @NamedTuple{logprior::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, loglikelihood::ReverseDiff.TrackedReal{Float64, Float64, Nothing}}})
  The function `seeded_reverse_pass!` exists, but no method is defined for this combination of argument types.
```

Judging by the argument types, the recorded `GradientTape`'s output is a NamedTuple `(logprior, loglikelihood)` of `TrackedReal`s rather than a single scalar `TrackedReal`, and `seeded_reverse_pass!` has no method for that.

Replacing `LogDensityAt` with the old closure, `x -> logdensity_at(x, ...)`, fixes the error.
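To make the difference between the two forms concrete, here is a toy sketch with no DynamicPPL involved. All names (`toy_logdensity`, `ToyLogDensityAt`, `make_closure`) are hypothetical stand-ins, not DynamicPPL's actual definitions: the callable struct stores the constant arguments as fields, while the closure captures the same constants anonymously. Both compute the same value; the issue is only that ReverseDiff handles the struct form differently in the reproduction above.

```julia
# Hypothetical stand-in for a log density: independent Normal(0, scale) terms.
toy_logdensity(x::AbstractVector, scale::Real) =
    -sum(abs2, x ./ scale) / 2 - length(x) * (log(2π) / 2 + log(scale))

# Callable-struct form (the style #922 introduced with LogDensityAt):
# the constant arguments are stored as fields of a concrete type.
struct ToyLogDensityAt{T<:Real}
    scale::T
end
(f::ToyLogDensityAt)(x::AbstractVector) = toy_logdensity(x, f.scale)

# Closure form (the style used before #922): the same constants are
# captured by an anonymous function.
make_closure(scale::Real) = x -> toy_logdensity(x, scale)

f_struct = ToyLogDensityAt(1.0)
f_closure = make_closure(1.0)
f_struct([0.0, 1.0]) == f_closure([0.0, 1.0])  # same value either way
```

With ReverseDiff loaded, the natural comparison would be `ReverseDiff.gradient(f_struct, x)` versus `ReverseDiff.gradient(f_closure, x)`. This toy returns a plain scalar, so it is not expected to reproduce the failure on its own; the bug appears to depend on the accumulator machinery.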

Another solution is to stop ReverseDiff from using the closure, i.e. declare `use_closure(::AutoReverseDiff) = false`. (That would in fact make `use_closure` redundant, since ReverseDiff was the only backend using the closure.)
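A rough sketch of how a trait like `use_closure` selects between the two calling conventions, assuming (as described above) that ReverseDiff is the only backend opting into the closure path. The backend types, `toy_logdensity` and `prepare_call` are hypothetical, not ADTypes or DynamicPPL code; the real trait's default and methods may differ. An anonymous closure stands in for whatever the closure path builds (after #922, the `LogDensityAt` struct), and the non-closure path is assumed to hand the constants to the AD backend as separate arguments.

```julia
# Hypothetical stand-ins for ADTypes backend types (not the real ADTypes structs).
abstract type ToyADType end
struct ToyAutoForwardDiff <: ToyADType end
struct ToyAutoReverseDiff <: ToyADType end

# Hypothetical trait: only ReverseDiff opts into the closure path.
# The proposed workaround amounts to making the ReverseDiff method return
# false (or deleting it), after which the trait no longer distinguishes anything.
use_closure(::ToyADType) = false
use_closure(::ToyAutoReverseDiff) = true

toy_logdensity(x::AbstractVector, scale::Real) = -sum(abs2, x ./ scale) / 2

# Returns the function to differentiate plus any constant arguments that
# would be passed to the AD backend alongside it.
function prepare_call(backend::ToyADType, scale::Real)
    if use_closure(backend)
        return (x -> toy_logdensity(x, scale)), ()
    else
        return toy_logdensity, (scale,)
    end
end

f, consts = prepare_call(ToyAutoForwardDiff(), 2.0)
g, no_consts = prepare_call(ToyAutoReverseDiff(), 2.0)
f([2.0], consts...) == g([2.0], no_consts...)  # both paths compute the same value
```

The benchmarking question then reduces to comparing these two paths for ReverseDiff specifically.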

These are the most direct fixes, but I'm not sure whether there's a better one (I need to benchmark them).

In any case, this also looks like an upstream ReverseDiff issue (although my hopes of getting it fixed there are ... slim, so we had better find a workaround here).
