Skip to content

Commit 38fd931

Browse files
committed
Combine LogDensityFunction{,WithGrad} into one
1 parent 5b05ad3 commit 38fd931

File tree

2 files changed

+155
-118
lines changed

2 files changed

+155
-118
lines changed

src/logdensityfunction.jl

Lines changed: 142 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
import DifferentiationInterface as DI
22

33
"""
4-
LogDensityFunction
4+
LogDensityFunction(
5+
model::Model,
6+
varinfo::AbstractVarInfo=VarInfo(model),
7+
context::AbstractContext=DefaultContext(),
8+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing)
59
610
A callable representing a log density function of a `model`.
7-
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
8-
but only to 0th-order, i.e. it is only possible to calculate the log density,
9-
and not its gradient. If you need to calculate the gradient as well, you have
10-
to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
11+
12+
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
13+
If the keyword argument `adtype` is provided, the log density function will be
14+
differentiable with respect to its input, and the gradient can be calculated
15+
using the AD backend specified by `adtype`. If not, then `adtype` defaults to
16+
nothing and it will only be possible to calculate the log density, not its
17+
gradient.
1118
1219
# Fields
1320
$(FIELDS)
@@ -50,64 +57,143 @@ julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
5057
true
5158
```
5259
"""
53-
struct LogDensityFunction{V,M,C}
60+
struct LogDensityFunction{
61+
V<:AbstractVarInfo,
62+
M<:Model,
63+
C<:Union{Nothing,AbstractContext},
64+
AD<:Union{Nothing,ADTypes.AbstractADType},
65+
}
5466
"varinfo used for evaluation"
5567
varinfo::V
5668
"model used for evaluation"
5769
model::M
5870
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
5971
context::C
60-
end
61-
62-
function LogDensityFunction(
63-
model::Model,
64-
varinfo::AbstractVarInfo=VarInfo(model),
65-
context::Union{Nothing,AbstractContext}=nothing,
66-
)
67-
return LogDensityFunction(varinfo, model, context)
68-
end
72+
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
73+
adtype::AD
74+
"gradient preparation object for the model; used internally only"
75+
prep::Union{Nothing,DI.GradientPrep}
76+
"whether a closure was used for the gradient preparation"
77+
with_closure::Bool
6978

70-
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
71-
function getcontext(f::LogDensityFunction)
72-
return f.context === nothing ? leafcontext(f.model.context) : f.context
79+
function LogDensityFunction(
80+
model::Model,
81+
varinfo::AbstractVarInfo=VarInfo(model),
82+
context::AbstractContext=DefaultContext();
83+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
84+
)
85+
if adtype === nothing
86+
prep = nothing
87+
with_closure = false
88+
else
89+
# Get a set of dummy params to use for prep
90+
x = map(identity, varinfo[:])
91+
with_closure = use_closure(adtype)
92+
if with_closure
93+
prep = DI.prepare_gradient(
94+
x -> logdensity_at(x, model, varinfo, context), adtype, x
95+
)
96+
else
97+
prep = DI.prepare_gradient(
98+
logdensity_at,
99+
adtype,
100+
x,
101+
DI.Constant(model),
102+
DI.Constant(varinfo),
103+
DI.Constant(context),
104+
)
105+
end
106+
with_closure = with_closure
107+
end
108+
return new{typeof(varinfo),typeof(model),typeof(context),typeof(adtype)}(
109+
varinfo, model, context, adtype, prep, with_closure
110+
)
111+
end
73112
end
74113

75114
"""
76-
getmodel(f)
115+
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
77116
78-
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
79-
"""
80-
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
117+
Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
118+
This function also performs preparation of the gradient, and sets the `prep`
119+
and `with_closure` fields of the LogDensityFunction.
81120
82-
"""
83-
setmodel(f, model[, adtype])
121+
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
84122
85-
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
123+
This function returns a new LogDensityFunction with the updated AD type, i.e. it does
124+
not mutate the input LogDensityFunction.
86125
"""
87-
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
88-
return Accessors.@set f.model = model
126+
function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
127+
return if adtype === f.adtype
128+
f # Avoid recomputing prep if not needed
129+
else
130+
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
131+
end
89132
end
90133

91134
"""
92-
getparams(f::LogDensityFunction)
93-
94-
Return the parameters of the wrapped varinfo as a vector.
135+
logdensity_at(
136+
x::AbstractVector,
137+
model::Model,
138+
varinfo::AbstractVarInfo,
139+
context::AbstractContext
140+
)
141+
142+
Evaluate the log density of the given `model` at the given parameter values `x`,
143+
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
144+
only for its structure, in the sense that the parameters from the vector `x` are inserted into
145+
it, and its own parameters are discarded.
95146
"""
96-
getparams(f::LogDensityFunction) = f.varinfo[:]
147+
function logdensity_at(
148+
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
149+
)
150+
varinfo_new = unflatten(varinfo, x)
151+
return getlogp(last(evaluate!!(model, varinfo_new, context)))
152+
end
153+
154+
### LogDensityProblems interface
97155

98-
# LogDensityProblems interface: logp (0th order)
156+
function LogDensityProblems.capabilities(
157+
::Type{<:LogDensityFunction{V,M,C,Nothing}}
158+
) where {V,M,C}
159+
return LogDensityProblems.LogDensityOrder{0}()
160+
end
161+
function LogDensityProblems.capabilities(
162+
::Type{<:LogDensityFunction{V,M,C,AD}}
163+
) where {V,M,C,AD<:ADTypes.AbstractADType}
164+
return LogDensityProblems.LogDensityOrder{1}()
165+
end
99166
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
100-
context = getcontext(f)
101-
vi_new = unflatten(f.varinfo, x)
102-
return getlogp(last(evaluate!!(f.model, vi_new, context)))
167+
return logdensity_at(x, f.model, f.varinfo, f.context)
103168
end
104-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
105-
return LogDensityProblems.LogDensityOrder{0}()
169+
function LogDensityProblems.logdensity_and_gradient(
170+
f::LogDensityFunction{V,M,C,AD}, x::AbstractVector
171+
) where {V,M,C,AD<:ADTypes.AbstractADType}
172+
f.prep === nothing &&
173+
error("Gradient preparation not available; this should not happen")
174+
x = map(identity, x) # Concretise type
175+
return if f.with_closure
176+
DI.value_and_gradient(
177+
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
178+
)
179+
else
180+
DI.value_and_gradient(
181+
logdensity_at,
182+
f.prep,
183+
f.adtype,
184+
x,
185+
DI.Constant(f.model),
186+
DI.Constant(f.varinfo),
187+
DI.Constant(f.context),
188+
)
189+
end
106190
end
191+
107192
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
108193
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
109194

110-
# LogDensityProblems interface: gradient (1st order)
195+
### Utils
196+
111197
"""
112198
use_closure(adtype::ADTypes.AbstractADType)
113199
@@ -138,76 +224,30 @@ use_closure(::ADTypes.AutoForwardDiff) = false
138224
use_closure(::ADTypes.AutoMooncake) = false
139225
use_closure(::ADTypes.AutoReverseDiff) = true
140226

141-
"""
142-
_flipped_logdensity(f::LogDensityFunction, x::AbstractVector)
143-
144-
This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
145-
arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
146-
(see `use_closure` for more information).
147-
"""
148-
function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
149-
return LogDensityProblems.logdensity(f, x)
227+
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
228+
function getcontext(f::LogDensityFunction)
229+
return f.context === nothing ? leafcontext(f.model.context) : f.context
150230
end
151231

152232
"""
153-
LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
154-
155-
A callable representing a log density function of a `model`.
156-
`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
157-
interface to 1st-order, meaning that you can both calculate the log density
158-
using
159-
160-
LogDensityProblems.logdensity(f, x)
233+
getmodel(f)
161234
162-
and its gradient using
235+
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
236+
"""
237+
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
163238

164-
LogDensityProblems.logdensity_and_gradient(f, x)
239+
"""
240+
setmodel(f, model[, adtype])
165241
166-
where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
242+
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
243+
"""
244+
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
245+
return Accessors.@set f.model = model
246+
end
167247

168-
# Fields
169-
$(FIELDS)
170248
"""
171-
struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
172-
ldf::LogDensityFunction{V,M,C}
173-
adtype::TAD
174-
prep::DI.GradientPrep
175-
with_closure::Bool
249+
getparams(f::LogDensityFunction)
176250
177-
function LogDensityFunctionWithGrad(
178-
ldf::LogDensityFunction{V,M,C}, adtype::TAD
179-
) where {V,M,C,TAD}
180-
# Get a set of dummy params to use for prep
181-
x = map(identity, getparams(ldf))
182-
with_closure = use_closure(adtype)
183-
if with_closure
184-
prep = DI.prepare_gradient(
185-
Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
186-
)
187-
else
188-
prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
189-
end
190-
# Store the prep with the struct. We also store whether a closure was used because
191-
# we need to know this when calling `DI.value_and_gradient`. In practice we could
192-
# recalculate it, but this runs the risk of introducing inconsistencies.
193-
return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
194-
end
195-
end
196-
function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
197-
return LogDensityProblems.logdensity(f.ldf)
198-
end
199-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
200-
return LogDensityProblems.LogDensityOrder{1}()
201-
end
202-
function LogDensityProblems.logdensity_and_gradient(
203-
f::LogDensityFunctionWithGrad, x::AbstractVector
204-
)
205-
x = map(identity, x) # Concretise type
206-
return if f.with_closure
207-
DI.value_and_gradient(
208-
Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
209-
)
210-
else
211-
DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
212-
end
213-
end
251+
Return the parameters of the wrapped varinfo as a vector.
252+
"""
253+
getparams(f::LogDensityFunction) = f.varinfo[:]

test/ad.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
1+
using DynamicPPL: LogDensityFunction
22

33
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
44
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@@ -10,11 +10,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
1010
f = LogDensityFunction(m, varinfo)
1111
x = DynamicPPL.getparams(f)
1212
# Calculate reference logp + gradient of logp using ForwardDiff
13-
default_adtype = ADTypes.AutoForwardDiff()
14-
ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype)
15-
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
16-
ldf_with_grad, x
17-
)
13+
ref_adtype = ADTypes.AutoForwardDiff()
14+
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
15+
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
1816

1917
@testset "$adtype" for adtype in [
2018
AutoReverseDiff(; compile=false),
@@ -33,20 +31,18 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
3331
# Mooncake doesn't work with several combinations of SimpleVarInfo.
3432
if is_mooncake && is_1_11 && is_svi_vnv
3533
# https://github.com/compintell/Mooncake.jl/issues/470
36-
@test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype)
34+
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
3735
elseif is_mooncake && is_1_10 && is_svi_vnv
3836
# TODO: report upstream
39-
@test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype)
37+
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
4038
elseif is_mooncake && is_1_10 && is_svi_od
4139
# TODO: report upstream
42-
@test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad(
43-
f, adtype
40+
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
41+
ref_ldf, adtype
4442
)
4543
else
46-
ldf_with_grad = LogDensityFunctionWithGrad(f, adtype)
47-
logp, grad = LogDensityProblems.logdensity_and_gradient(
48-
ldf_with_grad, x
49-
)
44+
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
45+
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
5046
@test grad ≈ ref_grad
5147
@test logp ≈ ref_logp
5248
end
@@ -90,8 +86,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
9086
# Compiling the ReverseDiff tape used to fail here
9187
spl = Sampler(MyEmptyAlg())
9288
vi = VarInfo(model)
93-
ldf = LogDensityFunction(vi, model, SamplingContext(spl))
94-
ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
89+
ldf = LogDensityFunction(
90+
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
91+
)
9592
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
9693
end
9794
end

0 commit comments

Comments
 (0)