@@ -106,6 +106,8 @@ struct LogDensityFunction{
     adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}
+    "(internal use only) the closure used for the gradient preparation"
+    closure::Union{Nothing,Function}
 
     function LogDensityFunction(
         model::Model,
@@ -124,10 +126,16 @@ struct LogDensityFunction{
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
             if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    x -> logdensity_at(x, model, varinfo, context), adtype, x
-                )
+                # The closure itself has to be stored inside the
+                # LogDensityFunction to ensure that the signature of the
+                # function being differentiated is the same as that used for
+                # preparation. See
+                # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
+                # explanation.
+                closure = x -> logdensity_at(x, model, varinfo, context)
+                prep = DI.prepare_gradient(closure, adtype, x)
             else
+                closure = nothing
                 prep = DI.prepare_gradient(
                     logdensity_at,
                     adtype,
@@ -139,7 +147,7 @@ struct LogDensityFunction{
             end
         end
         return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep
+            model, varinfo, context, adtype, prep, closure
         )
     end
 end
@@ -208,9 +216,8 @@ function LogDensityProblems.logdensity_and_gradient(
     # Make branching statically inferrable, i.e. type-stable (even if the two
     # branches happen to return different types)
     return if use_closure(f.adtype)
-        DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
-        )
+        f.closure === nothing && error("Closure not available; this should not happen")
+        DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
     else
         DI.value_and_gradient(
             logdensity_at,
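For context, a minimal standalone sketch (not DynamicPPL code) of the DifferentiationInterface pattern this diff relies on: the prep object returned by `DI.prepare_gradient` is only valid for the callable it was prepared with, so the same stored closure must be passed again to `DI.value_and_gradient`. The `AutoForwardDiff` backend and the toy objective below are illustrative assumptions, not part of the PR.

```julia
# Sketch (assumes ForwardDiff backend and a toy objective) of why the closure
# is stored: preparation and evaluation must see the same function object,
# mirroring the `closure` field added to `LogDensityFunction` in this diff.
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
x = randn(3)

# Create the closure once and keep it around.
closure = x -> sum(abs2, x)  # stand-in for x -> logdensity_at(x, ...)
prep = prepare_gradient(closure, backend, x)

# Later calls reuse the *same* stored closure together with its prep.
val, grad = value_and_gradient(closure, prep, backend, x)
```

If a fresh anonymous closure were created at the call site instead (as in the code being removed here), the function passed to `value_and_gradient` could differ from the one used for preparation, which is exactly what storing the closure in the struct avoids; see the linked PR #922 discussion for the full explanation.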