Skip to content

Commit aef4b91

Browse files
committed
Combine LogDensityFunction{,WithGrad} into one
1 parent 0e24d97 commit aef4b91

File tree

3 files changed

+165
-122
lines changed

3 files changed

+165
-122
lines changed

docs/src/api.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,10 @@ logjoint
5454

5555
### LogDensityProblems.jl interface
5656

57-
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`.
57+
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.
5858

5959
```@docs
6060
DynamicPPL.LogDensityFunction
61-
DynamicPPL.LogDensityFunctionWithGrad
6261
```
6362

6463
## Condition and decondition

src/logdensityfunction.jl

Lines changed: 150 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
import DifferentiationInterface as DI
22

33
"""
4-
LogDensityFunction
4+
LogDensityFunction(
5+
model::Model,
6+
varinfo::AbstractVarInfo=VarInfo(model),
7+
context::AbstractContext=DefaultContext(),
8+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing)
59
6-
A callable representing a log density function of a `model`.
7-
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
8-
but only to 0th-order, i.e. it is only possible to calculate the log density,
9-
and not its gradient. If you need to calculate the gradient as well, you have
10-
to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
10+
A struct which contains a model, along with all the information necessary to:
11+
12+
- calculate its log density at a given point;
13+
- and if `adtype` is provided, calculate the gradient of the log density at that point.
14+
15+
At its most basic level, a LogDensityFunction wraps the model together with
16+
the type of varinfo to be used, as well as the evaluation context. These must
17+
be known in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
18+
19+
If `adtype` is provided, then this struct will also contain the adtype along with
20+
other information for efficient calculation of the gradient of the log density.
21+
22+
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
23+
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
24+
concrete AD backend type, then `logdensity_and_gradient` is also implemented.
1125
1226
# Fields
1327
$(FIELDS)
@@ -50,64 +64,143 @@ julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
5064
true
5165
```
5266
"""
53-
struct LogDensityFunction{V,M,C}
67+
struct LogDensityFunction{
68+
V<:AbstractVarInfo,
69+
M<:Model,
70+
C<:Union{Nothing,AbstractContext},
71+
AD<:Union{Nothing,ADTypes.AbstractADType},
72+
}
5473
"varinfo used for evaluation"
5574
varinfo::V
5675
"model used for evaluation"
5776
model::M
5877
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
5978
context::C
60-
end
61-
62-
function LogDensityFunction(
63-
model::Model,
64-
varinfo::AbstractVarInfo=VarInfo(model),
65-
context::Union{Nothing,AbstractContext}=nothing,
66-
)
67-
return LogDensityFunction(varinfo, model, context)
68-
end
79+
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
80+
adtype::AD
81+
"(internal use only) gradient preparation object for the model"
82+
prep::Union{Nothing,DI.GradientPrep}
83+
"(internal use only) whether a closure was used for the gradient preparation"
84+
with_closure::Bool
6985

70-
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
71-
function getcontext(f::LogDensityFunction)
72-
return f.context === nothing ? leafcontext(f.model.context) : f.context
86+
function LogDensityFunction(
87+
model::Model,
88+
varinfo::AbstractVarInfo=VarInfo(model),
89+
context::AbstractContext=DefaultContext();
90+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
91+
)
92+
if adtype === nothing
93+
prep = nothing
94+
with_closure = false
95+
else
96+
# Get a set of dummy params to use for prep
97+
x = map(identity, varinfo[:])
98+
with_closure = use_closure(adtype)
99+
if with_closure
100+
prep = DI.prepare_gradient(
101+
x -> logdensity_at(x, model, varinfo, context), adtype, x
102+
)
103+
else
104+
prep = DI.prepare_gradient(
105+
logdensity_at,
106+
adtype,
107+
x,
108+
DI.Constant(model),
109+
DI.Constant(varinfo),
110+
DI.Constant(context),
111+
)
112+
end
113+
with_closure = with_closure
114+
end
115+
return new{typeof(varinfo),typeof(model),typeof(context),typeof(adtype)}(
116+
varinfo, model, context, adtype, prep, with_closure
117+
)
118+
end
73119
end
74120

75121
"""
76-
getmodel(f)
122+
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
77123
78-
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
79-
"""
80-
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
124+
Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
125+
This function also performs preparation of the gradient, and sets the `prep`
126+
and `with_closure` fields of the LogDensityFunction.
81127
82-
"""
83-
setmodel(f, model[, adtype])
128+
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
84129
85-
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
130+
This function returns a new LogDensityFunction with the updated AD type, i.e. it does
131+
not mutate the input LogDensityFunction.
86132
"""
87-
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
88-
return Accessors.@set f.model = model
133+
function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
134+
return if adtype === f.adtype
135+
f # Avoid recomputing prep if not needed
136+
else
137+
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
138+
end
89139
end
90140

91141
"""
92-
getparams(f::LogDensityFunction)
93-
94-
Return the parameters of the wrapped varinfo as a vector.
142+
logdensity_at(
143+
x::AbstractVector,
144+
model::Model,
145+
varinfo::AbstractVarInfo,
146+
context::AbstractContext
147+
)
148+
149+
Evaluate the log density of the given `model` at the given parameter values `x`,
150+
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
151+
only for its structure, in the sense that the parameters from the vector `x` are inserted into
152+
it, and its own parameters are discarded.
95153
"""
96-
getparams(f::LogDensityFunction) = f.varinfo[:]
154+
function logdensity_at(
155+
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
156+
)
157+
varinfo_new = unflatten(varinfo, x)
158+
return getlogp(last(evaluate!!(model, varinfo_new, context)))
159+
end
160+
161+
### LogDensityProblems interface
97162

98-
# LogDensityProblems interface: logp (0th order)
163+
function LogDensityProblems.capabilities(
164+
::Type{<:LogDensityFunction{V,M,C,Nothing}}
165+
) where {V,M,C}
166+
return LogDensityProblems.LogDensityOrder{0}()
167+
end
168+
function LogDensityProblems.capabilities(
169+
::Type{<:LogDensityFunction{V,M,C,AD}}
170+
) where {V,M,C,AD<:ADTypes.AbstractADType}
171+
return LogDensityProblems.LogDensityOrder{1}()
172+
end
99173
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
100-
context = getcontext(f)
101-
vi_new = unflatten(f.varinfo, x)
102-
return getlogp(last(evaluate!!(f.model, vi_new, context)))
174+
return logdensity_at(x, f.model, f.varinfo, f.context)
103175
end
104-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
105-
return LogDensityProblems.LogDensityOrder{0}()
176+
function LogDensityProblems.logdensity_and_gradient(
177+
f::LogDensityFunction{V,M,C,AD}, x::AbstractVector
178+
) where {V,M,C,AD<:ADTypes.AbstractADType}
179+
f.prep === nothing &&
180+
error("Gradient preparation not available; this should not happen")
181+
x = map(identity, x) # Concretise type
182+
return if f.with_closure
183+
DI.value_and_gradient(
184+
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
185+
)
186+
else
187+
DI.value_and_gradient(
188+
logdensity_at,
189+
f.prep,
190+
f.adtype,
191+
x,
192+
DI.Constant(f.model),
193+
DI.Constant(f.varinfo),
194+
DI.Constant(f.context),
195+
)
196+
end
106197
end
198+
107199
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
108200
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
109201

110-
# LogDensityProblems interface: gradient (1st order)
202+
### Utils
203+
111204
"""
112205
use_closure(adtype::ADTypes.AbstractADType)
113206
@@ -138,76 +231,30 @@ use_closure(::ADTypes.AutoForwardDiff) = false
138231
use_closure(::ADTypes.AutoMooncake) = false
139232
use_closure(::ADTypes.AutoReverseDiff) = true
140233

141-
"""
142-
_flipped_logdensity(f::LogDensityFunction, x::AbstractVector)
143-
144-
This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
145-
arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
146-
(see `use_closure` for more information).
147-
"""
148-
function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
149-
return LogDensityProblems.logdensity(f, x)
234+
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
235+
function getcontext(f::LogDensityFunction)
236+
return f.context === nothing ? leafcontext(f.model.context) : f.context
150237
end
151238

152239
"""
153-
LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
154-
155-
A callable representing a log density function of a `model`.
156-
`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
157-
interface to 1st-order, meaning that you can both calculate the log density
158-
using
159-
160-
LogDensityProblems.logdensity(f, x)
240+
getmodel(f)
161241
162-
and its gradient using
242+
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
243+
"""
244+
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
163245

164-
LogDensityProblems.logdensity_and_gradient(f, x)
246+
"""
247+
setmodel(f, model[, adtype])
165248
166-
where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
249+
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
250+
"""
251+
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
252+
return Accessors.@set f.model = model
253+
end
167254

168-
# Fields
169-
$(FIELDS)
170255
"""
171-
struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
172-
ldf::LogDensityFunction{V,M,C}
173-
adtype::TAD
174-
prep::DI.GradientPrep
175-
with_closure::Bool
256+
getparams(f::LogDensityFunction)
176257
177-
function LogDensityFunctionWithGrad(
178-
ldf::LogDensityFunction{V,M,C}, adtype::TAD
179-
) where {V,M,C,TAD}
180-
# Get a set of dummy params to use for prep
181-
x = map(identity, getparams(ldf))
182-
with_closure = use_closure(adtype)
183-
if with_closure
184-
prep = DI.prepare_gradient(
185-
Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
186-
)
187-
else
188-
prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
189-
end
190-
# Store the prep with the struct. We also store whether a closure was used because
191-
# we need to know this when calling `DI.value_and_gradient`. In practice we could
192-
# recalculate it, but this runs the risk of introducing inconsistencies.
193-
return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
194-
end
195-
end
196-
function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
197-
return LogDensityProblems.logdensity(f.ldf)
198-
end
199-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
200-
return LogDensityProblems.LogDensityOrder{1}()
201-
end
202-
function LogDensityProblems.logdensity_and_gradient(
203-
f::LogDensityFunctionWithGrad, x::AbstractVector
204-
)
205-
x = map(identity, x) # Concretise type
206-
return if f.with_closure
207-
DI.value_and_gradient(
208-
Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
209-
)
210-
else
211-
DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
212-
end
213-
end
258+
Return the parameters of the wrapped varinfo as a vector.
259+
"""
260+
getparams(f::LogDensityFunction) = f.varinfo[:]

test/ad.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
1+
using DynamicPPL: LogDensityFunction
22

33
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
44
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@@ -10,11 +10,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
1010
f = LogDensityFunction(m, varinfo)
1111
x = DynamicPPL.getparams(f)
1212
# Calculate reference logp + gradient of logp using ForwardDiff
13-
default_adtype = ADTypes.AutoForwardDiff()
14-
ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype)
15-
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
16-
ldf_with_grad, x
17-
)
13+
ref_adtype = ADTypes.AutoForwardDiff()
14+
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
15+
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
1816

1917
@testset "$adtype" for adtype in [
2018
AutoReverseDiff(; compile=false),
@@ -33,20 +31,18 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
3331
# Mooncake doesn't work with several combinations of SimpleVarInfo.
3432
if is_mooncake && is_1_11 && is_svi_vnv
3533
# https://github.com/compintell/Mooncake.jl/issues/470
36-
@test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype)
34+
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
3735
elseif is_mooncake && is_1_10 && is_svi_vnv
3836
# TODO: report upstream
39-
@test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype)
37+
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
4038
elseif is_mooncake && is_1_10 && is_svi_od
4139
# TODO: report upstream
42-
@test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad(
43-
f, adtype
40+
@test_throws Mooncake.MooncakeRuleCompilationError setadtype(
41+
ref_ldf, adtype
4442
)
4543
else
46-
ldf_with_grad = LogDensityFunctionWithGrad(f, adtype)
47-
logp, grad = LogDensityProblems.logdensity_and_gradient(
48-
ldf_with_grad, x
49-
)
44+
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
45+
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
5046
@test grad ≈ ref_grad
5147
@test logp ≈ ref_logp
5248
end
@@ -90,8 +86,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
9086
# Compiling the ReverseDiff tape used to fail here
9187
spl = Sampler(MyEmptyAlg())
9288
vi = VarInfo(model)
93-
ldf = LogDensityFunction(vi, model, SamplingContext(spl))
94-
ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
95-
@test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any
89+
ldf = LogDensityFunction(
90+
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
91+
)
92+
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
9693
end
9794
end

0 commit comments

Comments
 (0)