@@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818""" 
1919    LogDensityFunction( 
2020        model::Model, 
21-         varinfo::AbstractVarInfo=VarInfo(model); 
21+         getlogdensity::Function=getlogjoint, 
22+         varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); 
2223        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing 
2324    ) 
2425
@@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to:
2829 - and if `adtype` is provided, calculate the gradient of the log density at 
2930 that point. 
3031
31- At its most basic level, a LogDensityFunction wraps the model together with the 
32- type of varinfo to be used. These must be known in order to calculate the log 
33- density (using [`DynamicPPL.evaluate!!`](@ref)). 
32+ At its most basic level, a LogDensityFunction wraps the model together with a 
33+ function that specifies how to extract the log density, and the type of  
34+ VarInfo to be used. These must be known in order to calculate the log density 
35+ (using [`DynamicPPL.evaluate!!`](@ref)). 
3436
3537If the `adtype` keyword argument is provided, then this struct will also store 
3638the adtype along with other information for efficient calculation of the 
@@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
72741 
7375
7476julia> # By default it uses `VarInfo` under the hood, but this is not necessary. 
75-        f = LogDensityFunction(model, SimpleVarInfo(model)); 
77+        f = LogDensityFunction(model, getlogjoint,  SimpleVarInfo(model)); 
7678
7779julia> LogDensityProblems.logdensity(f, [0.0]) 
7880-2.3378770664093453 
7981
80- julia> # LogDensityFunction respects  the accumulators in VarInfo : 
81-        f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) ); 
82+ julia> # One can also specify evaluating e.g.  the log prior only : 
83+        f_prior = LogDensityFunction(model, getlogprior ); 
8284
8385julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) 
8486true 
@@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
9395``` 
9496""" 
9597struct  LogDensityFunction{
96-     M<: Model ,V<: AbstractVarInfo ,AD<: Union{Nothing,ADTypes.AbstractADType} 
98+     M<: Model ,F <: Function , V<: AbstractVarInfo ,AD<: Union{Nothing,ADTypes.AbstractADType} 
9799} <:  AbstractModel 
98100    " model used for evaluation" 
99101    model:: M 
100-     " varinfo used for evaluation" 
102+     " function to be called on `varinfo` to extract the log density. By default `getlogjoint`." 
103+     getlogdensity:: F 
104+     " varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." 
101105    varinfo:: V 
102106    " AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" 
103107    adtype:: AD 
@@ -106,7 +110,8 @@ struct LogDensityFunction{
106110
107111    function  LogDensityFunction (
108112        model:: Model ,
109-         varinfo:: AbstractVarInfo = VarInfo (model);
113+         getlogdensity:: Function = getlogjoint,
114+         varinfo:: AbstractVarInfo = ldf_default_varinfo (model, getlogdensity);
110115        adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
111116    )
112117        if  adtype ===  nothing 
@@ -120,15 +125,22 @@ struct LogDensityFunction{
120125            #  Get a set of dummy params to use for prep
121126            x =  map (identity, varinfo[:])
122127            if  use_closure (adtype)
123-                 prep =  DI. prepare_gradient (LogDensityAt (model, varinfo), adtype, x)
128+                 prep =  DI. prepare_gradient (
129+                     LogDensityAt (model, getlogdensity, varinfo), adtype, x
130+                 )
124131            else 
125132                prep =  DI. prepare_gradient (
126-                     logdensity_at, adtype, x, DI. Constant (model), DI. Constant (varinfo)
133+                     logdensity_at,
134+                     adtype,
135+                     x,
136+                     DI. Constant (model),
137+                     DI. Constant (getlogdensity),
138+                     DI. Constant (varinfo),
127139                )
128140            end 
129141        end 
130-         return  new {typeof(model),typeof(varinfo),typeof(adtype)} (
131-             model, varinfo, adtype, prep
142+         return  new {typeof(model),typeof(getlogdensity),typeof( varinfo),typeof(adtype)} (
143+             model, getlogdensity,  varinfo, adtype, prep
132144        )
133145    end 
134146end 
@@ -149,83 +161,112 @@ function LogDensityFunction(
149161    return  if  adtype ===  f. adtype
150162        f  #  Avoid recomputing prep if not needed
151163    else 
152-         LogDensityFunction (f. model, f. varinfo; adtype= adtype)
164+         LogDensityFunction (f. model, f. getlogdensity, f . varinfo; adtype= adtype)
153165    end 
154166end 
155167
168+ """ 
169+     ldf_default_varinfo(model::Model, getlogdensity::Function) 
170+ 
171+ Create the default AbstractVarInfo that should be used for evaluating the log density. 
172+ 
173+ Only the accumulators necesessary for `getlogdensity` will be used. 
174+ """ 
175+ function  ldf_default_varinfo (:: Model , getlogdensity:: Function )
176+     msg =  """ 
177+     LogDensityFunction does not know what sort of VarInfo should be used when \ 
178+     `getlogdensity` is $getlogdensity . Please specify a VarInfo explicitly. 
179+     """  
180+     return  error (msg)
181+ end 
182+ 
183+ ldf_default_varinfo (model:: Model , :: typeof (getlogjoint)) =  VarInfo (model)
184+ 
185+ function  ldf_default_varinfo (model:: Model , :: typeof (getlogprior))
186+     return  setaccs!! (VarInfo (model), (LogPriorAccumulator (),))
187+ end 
188+ 
189+ function  ldf_default_varinfo (model:: Model , :: typeof (getloglikelihood))
190+     return  setaccs!! (VarInfo (model), (LogLikelihoodAccumulator (),))
191+ end 
192+ 
156193""" 
157194    logdensity_at( 
158195        x::AbstractVector, 
159196        model::Model, 
197+         getlogdensity::Function, 
160198        varinfo::AbstractVarInfo, 
161199    ) 
162200
163- Evaluate the log density of the given `model` at the given parameter values `x`, 
164- using the given `varinfo`. Note that the `varinfo` argument is provided only 
165- for its structure, in the sense that the parameters from the vector `x` are 
166- inserted into it, and its own parameters are discarded. It does, however, 
167- determine whether the log prior, likelihood, or joint is returned, based on 
168- which accumulators are set in it. 
201+ Evaluate the log density of the given `model` at the given parameter values 
202+ `x`, using the given `varinfo`. Note that the `varinfo` argument is provided 
203+ only for its structure, in the sense that the parameters from the vector `x` 
204+ are inserted into it, and its own parameters are discarded. `getlogdensity` is 
205+ the function that extracts the log density from the evaluated varinfo. 
169206""" 
170- function  logdensity_at (x:: AbstractVector , model:: Model , varinfo:: AbstractVarInfo )
207+ function  logdensity_at (
208+     x:: AbstractVector , model:: Model , getlogdensity:: Function , varinfo:: AbstractVarInfo 
209+ )
171210    varinfo_new =  unflatten (varinfo, x)
172211    varinfo_eval =  last (evaluate!! (model, varinfo_new))
173-     has_prior =  hasacc (varinfo_eval, Val (:LogPrior ))
174-     has_likelihood =  hasacc (varinfo_eval, Val (:LogLikelihood ))
175-     if  has_prior &&  has_likelihood
176-         return  getlogjoint (varinfo_eval)
177-     elseif  has_prior
178-         return  getlogprior (varinfo_eval)
179-     elseif  has_likelihood
180-         return  getloglikelihood (varinfo_eval)
181-     else 
182-         error (" LogDensityFunction: varinfo tracks neither log prior nor log likelihood" 
183-     end 
212+     return  getlogdensity (varinfo_eval)
184213end 
185214
186215""" 
187-     LogDensityAt{M<:Model,V<:AbstractVarInfo}( 
216+     LogDensityAt{M<:Model,F<:Function, V<:AbstractVarInfo}( 
188217        model::M 
218+         getlogdensity::F, 
189219        varinfo::V 
190220    ) 
191221
192222A callable struct that serves the same purpose as `x -> logdensity_at(x, model, 
193- varinfo)`. 
223+ getlogdensity,  varinfo)`.
194224""" 
195- struct  LogDensityAt{M<: Model ,V<: AbstractVarInfo }
225+ struct  LogDensityAt{M<: Model ,F <: Function , V<: AbstractVarInfo }
196226    model:: M 
227+     getlogdensity:: F 
197228    varinfo:: V 
198229end 
199- (ld:: LogDensityAt )(x:: AbstractVector ) =  logdensity_at (x, ld. model, ld. varinfo)
230+ function  (ld:: LogDensityAt )(x:: AbstractVector )
231+     return  logdensity_at (x, ld. model, ld. getlogdensity, ld. varinfo)
232+ end 
200233
201234# ## LogDensityProblems interface
202235
203236function  LogDensityProblems. capabilities (
204-     :: Type{<:LogDensityFunction{M,V,Nothing}} 
205- ) where  {M,V}
237+     :: Type{<:LogDensityFunction{M,F, V,Nothing}} 
238+ ) where  {M,F, V}
206239    return  LogDensityProblems. LogDensityOrder {0} ()
207240end 
208241function  LogDensityProblems. capabilities (
209-     :: Type{<:LogDensityFunction{M,V,AD}} 
210- ) where  {M,V,AD<: ADTypes.AbstractADType }
242+     :: Type{<:LogDensityFunction{M,F, V,AD}} 
243+ ) where  {M,F, V,AD<: ADTypes.AbstractADType }
211244    return  LogDensityProblems. LogDensityOrder {1} ()
212245end 
213246function  LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
214-     return  logdensity_at (x, f. model, f. varinfo)
247+     return  logdensity_at (x, f. model, f. getlogdensity, f . varinfo)
215248end 
216249function  LogDensityProblems. logdensity_and_gradient (
217-     f:: LogDensityFunction{M,V,AD} , x:: AbstractVector 
218- ) where  {M,V,AD<: ADTypes.AbstractADType }
250+     f:: LogDensityFunction{M,F, V,AD} , x:: AbstractVector 
251+ ) where  {M,F, V,AD<: ADTypes.AbstractADType }
219252    f. prep ===  nothing  && 
220253        error (" Gradient preparation not available; this should not happen" 
221254    x =  map (identity, x)  #  Concretise type
222255    #  Make branching statically inferrable, i.e. type-stable (even if the two
223256    #  branches happen to return different types)
224257    return  if  use_closure (f. adtype)
225-         DI. value_and_gradient (LogDensityAt (f. model, f. varinfo), f. prep, f. adtype, x)
258+         DI. value_and_gradient (
259+             LogDensityAt (f. model, f. getlogdensity, f. varinfo), f. prep, f. adtype, x
260+         )
226261    else 
227262        DI. value_and_gradient (
228-             logdensity_at, f. prep, f. adtype, x, DI. Constant (f. model), DI. Constant (f. varinfo)
263+             logdensity_at,
264+             f. prep,
265+             f. adtype,
266+             x,
267+             DI. Constant (f. model),
268+             DI. Constant (f. getlogdensity),
269+             DI. Constant (f. varinfo),
229270        )
230271    end 
231272end 
@@ -264,9 +305,9 @@ There are two ways of dealing with this:
264305
2653061. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) 
266307
267- 2. Use a constant context.  This lets us pass a two-argument function to 
268-    DifferentiationInterface,  as long as we also give it the 'inactive argument' 
269-    (i.e. the model) wrapped  in `DI.Constant`. 
308+ 2. Use a constant DI.Context.  This lets us pass a two-argument function to DI,  
309+    as long as we also give it the 'inactive argument' (i.e. the model) wrapped  
310+    in `DI.Constant`. 
270311
271312The relative performance of the two approaches, however, depends on the AD 
272313backend used. Some benchmarks are provided here: 
@@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292333Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. 
293334""" 
294335function  setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
295-     return  LogDensityFunction (model, f. varinfo; adtype= f. adtype)
336+     return  LogDensityFunction (model, f. getlogdensity, f . varinfo; adtype= f. adtype)
296337end 
297338
298339""" 
0 commit comments