|  | 
|  | 1 | +module DynamicPPLMarginalLogDensitiesExt | 
|  | 2 | + | 
|  | 3 | +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName | 
|  | 4 | +using MarginalLogDensities: MarginalLogDensities | 
|  | 5 | + | 
|  | 6 | +# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by | 
|  | 7 | +# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type | 
|  | 8 | +# below. | 
|  | 9 | +struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} | 
|  | 10 | +    logdensity::L | 
|  | 11 | +end | 
|  | 12 | +function (lw::LogDensityFunctionWrapper)(x, _) | 
|  | 13 | +    return LogDensityProblems.logdensity(lw.logdensity, x) | 
|  | 14 | +end | 
|  | 15 | + | 
|  | 16 | +""" | 
|  | 17 | +    marginalize( | 
|  | 18 | +        model::DynamicPPL.Model, | 
|  | 19 | +        marginalized_varnames::AbstractVector{<:VarName}; | 
|  | 20 | +        varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model), | 
|  | 21 | +        getlogprob=DynamicPPL.getlogjoint, | 
|  | 22 | +        method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); | 
|  | 23 | +        kwargs..., | 
|  | 24 | +    ) | 
|  | 25 | +
 | 
|  | 26 | +Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal | 
|  | 27 | +log-density of the given `model`, after marginalizing out the variables specified in | 
|  | 28 | +`varnames`. | 
|  | 29 | +
 | 
|  | 30 | +The resulting object can be called with a vector of parameter values to compute the marginal | 
|  | 31 | +log-density. | 
|  | 32 | +
 | 
|  | 33 | +## Keyword arguments | 
|  | 34 | +
 | 
|  | 35 | +- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`, | 
|  | 36 | +   meaning that the resulting log-density function accepts parameters that have been | 
|  | 37 | +   transformed to unconstrained space. | 
|  | 38 | +
 | 
|  | 39 | +- `getlogprob`: A function which specifies which kind of marginal log-density to compute. | 
|  | 40 | +   Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint | 
|  | 41 | +   probability. | 
|  | 42 | +
 | 
|  | 43 | +- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the | 
|  | 44 | +   MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) | 
|  | 45 | +   for other options. | 
|  | 46 | +
 | 
|  | 47 | +- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity` | 
|  | 48 | +  constructor. | 
|  | 49 | +
 | 
|  | 50 | +## Example | 
|  | 51 | +
 | 
|  | 52 | +```jldoctest | 
|  | 53 | +julia> using DynamicPPL, Distributions, MarginalLogDensities | 
|  | 54 | +
 | 
|  | 55 | +julia> @model function demo() | 
|  | 56 | +           x ~ Normal(1.0) | 
|  | 57 | +           y ~ Normal(2.0) | 
|  | 58 | +       end | 
|  | 59 | +demo (generic function with 2 methods) | 
|  | 60 | +
 | 
|  | 61 | +julia> marginalized = marginalize(demo(), [:x]); | 
|  | 62 | +
 | 
|  | 63 | +julia> # The resulting callable computes the marginal log-density of `y`. | 
|  | 64 | +       marginalized([1.0]) | 
|  | 65 | +-1.4189385332046727 | 
|  | 66 | +
 | 
|  | 67 | +julia> logpdf(Normal(2.0), 1.0) | 
|  | 68 | +-1.4189385332046727 | 
|  | 69 | +``` | 
|  | 70 | +
 | 
|  | 71 | +
 | 
|  | 72 | +!!! warning | 
|  | 73 | +
 | 
|  | 74 | +    The default usage of linked VarInfo means that, for example, optimization of the | 
|  | 75 | +    marginal log-density can be performed in unconstrained space. However, care must be | 
|  | 76 | +    taken if the model contains variables where the link transformation depends on a | 
|  | 77 | +    marginalized variable. For example: | 
|  | 78 | +
 | 
|  | 79 | +    ```julia | 
|  | 80 | +    @model function f() | 
|  | 81 | +        x ~ Normal() | 
|  | 82 | +        y ~ truncated(Normal(); lower=x) | 
|  | 83 | +    end | 
|  | 84 | +    ``` | 
|  | 85 | +
 | 
|  | 86 | +    Here, the support of `y`, and hence the link transformation used, depends on the value | 
|  | 87 | +    of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of | 
|  | 88 | +    `y` to log-probabilities. However, it will not be possible to use DynamicPPL to | 
|  | 89 | +    correctly retrieve _unlinked_ values of `y`. | 
|  | 90 | +""" | 
|  | 91 | +function DynamicPPL.marginalize( | 
|  | 92 | +    model::DynamicPPL.Model, | 
|  | 93 | +    marginalized_varnames::AbstractVector{<:VarName}; | 
|  | 94 | +    varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model), | 
|  | 95 | +    getlogprob::Function=DynamicPPL.getlogjoint, | 
|  | 96 | +    method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), | 
|  | 97 | +    kwargs..., | 
|  | 98 | +) | 
|  | 99 | +    # Determine the indices for the variables to marginalise out. | 
|  | 100 | +    varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames)) | 
|  | 101 | +    # Construct the marginal log-density model. | 
|  | 102 | +    f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) | 
|  | 103 | +    mld = MarginalLogDensities.MarginalLogDensity( | 
|  | 104 | +        LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... | 
|  | 105 | +    ) | 
|  | 106 | +    return mld | 
|  | 107 | +end | 
|  | 108 | + | 
|  | 109 | +""" | 
|  | 110 | +    VarInfo( | 
|  | 111 | +        mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, | 
|  | 112 | +        unmarginalized_params::Union{AbstractVector,Nothing}=nothing | 
|  | 113 | +    ) | 
|  | 114 | +
 | 
|  | 115 | +Retrieve the `VarInfo` object used in the marginalisation process. | 
|  | 116 | +
 | 
|  | 117 | +If a Laplace approximation was used for the marginalisation, the values of the marginalized | 
|  | 118 | +parameters are also set to their mode (note that this only happens if the `mld` object has | 
|  | 119 | +been used to compute the marginal log-density at least once, so that the mode has been | 
|  | 120 | +computed). | 
|  | 121 | +
 | 
|  | 122 | +If a vector of `unmarginalized_params` is specified, the values for the corresponding | 
|  | 123 | +parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by | 
|  | 124 | +performing an optimization of the marginal log-density. | 
|  | 125 | +
 | 
|  | 126 | +All other aspects of the VarInfo, such as link status, are preserved from the original | 
|  | 127 | +VarInfo used in the marginalisation. | 
|  | 128 | +
 | 
|  | 129 | +!!! note | 
|  | 130 | +
 | 
|  | 131 | +    The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be | 
|  | 132 | +    updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the | 
|  | 133 | +    model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model, | 
|  | 134 | +    vi))`). | 
|  | 135 | +
 | 
|  | 136 | +## Example | 
|  | 137 | +
 | 
|  | 138 | +```jldoctest | 
|  | 139 | +julia> using DynamicPPL, Distributions, MarginalLogDensities | 
|  | 140 | +
 | 
|  | 141 | +julia> @model function demo() | 
|  | 142 | +           x ~ Normal() | 
|  | 143 | +           y ~ Beta(2, 2) | 
|  | 144 | +       end | 
|  | 145 | +demo (generic function with 2 methods) | 
|  | 146 | +
 | 
|  | 147 | +julia> # Note that by default `marginalize` uses a linked VarInfo. | 
|  | 148 | +       mld = marginalize(demo(), [@varname(x)]); | 
|  | 149 | +
 | 
|  | 150 | +julia> using MarginalLogDensities: Optimization, OptimizationOptimJL | 
|  | 151 | +
 | 
|  | 152 | +julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. | 
|  | 153 | +       y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) | 
|  | 154 | +OptimizationProblem. In-place: true | 
|  | 155 | +u0: 1-element Vector{Float64}: | 
|  | 156 | + 2.0 | 
|  | 157 | +
 | 
|  | 158 | +julia> # This tells us the optimal (linked) value of `y` is around 0. | 
|  | 159 | +       opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) | 
|  | 160 | +retcode: Success | 
|  | 161 | +u: 1-element Vector{Float64}: | 
|  | 162 | + 4.88281250001733e-5 | 
|  | 163 | +
 | 
|  | 164 | +julia> # Get the VarInfo corresponding to the mode of `y`. | 
|  | 165 | +       vi = VarInfo(mld, opt_solution.u); | 
|  | 166 | +
 | 
|  | 167 | +julia> # `x` is set to its mode (which for `Normal()` is zero). | 
|  | 168 | +       vi[@varname(x)] | 
|  | 169 | +0.0 | 
|  | 170 | +
 | 
|  | 171 | +julia> # `y` is set to the optimal value we found above. | 
|  | 172 | +       DynamicPPL.getindex_internal(vi, @varname(y)) | 
|  | 173 | +1-element Vector{Float64}: | 
|  | 174 | + 4.88281250001733e-5 | 
|  | 175 | +
 | 
|  | 176 | +julia> # To obtain values in the original constrained space, we can either | 
|  | 177 | +       # use `getindex`: | 
|  | 178 | +       vi[@varname(y)] | 
|  | 179 | +0.5000122070312476 | 
|  | 180 | +
 | 
|  | 181 | +julia> # Or invlink the entire VarInfo object using the model: | 
|  | 182 | +       vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:] | 
|  | 183 | +2-element Vector{Float64}: | 
|  | 184 | + 0.0 | 
|  | 185 | + 0.5000122070312476 | 
|  | 186 | +``` | 
|  | 187 | +""" | 
|  | 188 | +function DynamicPPL.VarInfo( | 
|  | 189 | +    mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, | 
|  | 190 | +    unmarginalized_params::Union{AbstractVector,Nothing}=nothing, | 
|  | 191 | +) | 
|  | 192 | +    # Extract the original VarInfo. Its contents will in general be junk. | 
|  | 193 | +    original_vi = mld.logdensity.logdensity.varinfo | 
|  | 194 | +    # Extract the stored parameters, which includes the modes for any marginalized | 
|  | 195 | +    # parameters | 
|  | 196 | +    full_params = MarginalLogDensities.cached_params(mld) | 
|  | 197 | +    # We can then (if needed) set the values for any non-marginalized parameters | 
|  | 198 | +    if unmarginalized_params !== nothing | 
|  | 199 | +        full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params | 
|  | 200 | +    end | 
|  | 201 | +    return DynamicPPL.unflatten(original_vi, full_params) | 
|  | 202 | +end | 
|  | 203 | + | 
|  | 204 | +end | 
0 commit comments