@torfjelde torfjelde commented Apr 27, 2021

I'm just going to put this here before heading to bed, but there's some more stuff to do here:

  • Allow making observations/model arguments symbolic. I've verified that this works, but I just haven't gotten around to automatically constructing Variable for the model arguments yet.
  • Docstrings/tests
  • ???

Anyways, it's pretty dope.

julia> using DynamicPPL, Distributions

julia> @model function demo(x, ::Type{TV} = Vector{Float64}) where {TV}
           # Just to demonstrate that we're not restricted to `x ~ DistKnownAtExpansionTime`
           s_prior = InverseGamma()
           s ~ s_prior

           num_obs = length(x)
           m = TV(undef, num_obs)
           m[1] ~ Normal(0, s)
           x[1] ~ Normal(x[1], s)
           for i = 2:num_obs
               m[i] ~ Normal(m[i - 1], s)
               x[i] ~ Normal(m[i], s)
           end
           return m, x
       end;

julia> m = demo(randn(10) .+ 1);

Expression generation

EDIT: This doesn't work right now, as it seems the original rewriters are now out of date. We can still extract an expression using vi, θ = Symbolic.symbolize(m); getlogp(vi), but we won't have tracing through the logpdf computation.

julia> lp, θ = Symbolic.symbolic_logp(m) # can drop the constant in front too if we want
(-18.378770664093455 - (2.0log(θ₁)) - (20log(sqrt(θ₁))) - (0.5abs2(θ₂*(sqrt(θ₁)^-1))) - (0.5abs2((θ₁₀ - θ₉)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₁₁ - θ₁₀)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₃ - θ₂)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₄ - θ₃)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₅ - θ₄)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₆ - θ₅)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₇ - θ₆)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₈ - θ₇)*(sqrt(θ₁)^-1))) - (0.5abs2((θ₉ - θ₈)*(sqrt(θ₁)^-1))) - (0.5abs2((sqrt(θ₁)^-1)*(-1.4166318988896798 - θ₁₀))) - (0.5abs2((sqrt(θ₁)^-1)*(1.9786928906469687 - θ₁₁))) - (0.5abs2((sqrt(θ₁)^-1)*(-0.7713281377661316 - θ₃))) - (0.5abs2((sqrt(θ₁)^-1)*(0.5953344430416955 - θ₄))) - (0.5abs2((sqrt(θ₁)^-1)*(0.4888163369370143 - θ₅))) - (0.5abs2((sqrt(θ₁)^-1)*(0.9869369683122523 - θ₆))) - (0.5abs2((sqrt(θ₁)^-1)*(0.18968854568519278 - θ₇))) - (0.5abs2((sqrt(θ₁)^-1)*(2.20765228354745 - θ₈))) - (0.5abs2((sqrt(θ₁)^-1)*(1.1746311120999957 - θ₉))) - (θ₁^-1), Symbolics.Num[θ₁, θ₂, θ₃, θ₄, θ₅, θ₆, θ₇, θ₈, θ₉, θ₁₀, θ₁₁])

julia> ∂lp = Symbolics.gradient(lp, θ)
11-element Vector{Symbolics.Num}:
                                     θ₁^-2 + 0.5(θ₂^2)*(sqrt(θ₁)^-4) + 0.5(sqrt(θ₁)^-4)*((θ₁₀ - θ₉)^2) + 0.5(sqrt(θ₁)^-4)*((θ₁₁ - θ₁₀)^2) + 0.5(sqrt(θ₁)^-4)*((θ₃ - θ₂)^2) + 0.5(sqrt(θ₁)^-4)*((θ₄ - θ₃)^2) + 0.5(sqrt(θ₁)^-4)*((θ₅ - θ₄)^2) + 0.5(sqrt(θ₁)^-4)*((θ₆ - θ₅)^2) + 0.5(sqrt(θ₁)^-4)*((θ₇ - θ₆)^2) + 0.5(sqrt(θ₁)^-4)*((θ₈ - θ₇)^2) + 0.5(sqrt(θ₁)^-4)*((θ₉ - θ₈)^2) + 0.5(sqrt(θ₁)^-4)*((-1.4166318988896798 - θ₁₀)^2) + 0.5(sqrt(θ₁)^-4)*((1.9786928906469687 - θ₁₁)^2) + 0.5(sqrt(θ₁)^-4)*((-0.7713281377661316 - θ₃)^2) + 0.5(sqrt(θ₁)^-4)*((0.5953344430416955 - θ₄)^2) + 0.5(sqrt(θ₁)^-4)*((0.4888163369370143 - θ₅)^2) + 0.5(sqrt(θ₁)^-4)*((0.9869369683122523 - θ₆)^2) + 0.5(sqrt(θ₁)^-4)*((0.18968854568519278 - θ₇)^2) + 0.5(sqrt(θ₁)^-4)*((2.20765228354745 - θ₈)^2) + 0.5(sqrt(θ₁)^-4)*((1.1746311120999957 - θ₉)^2) - (2.0(θ₁^-1)) - ((10//1)*(sqrt(θ₁)^-2))
  (θ₃ - θ₂)*(sqrt(θ₁)^-2) - (θ₂*(sqrt(θ₁)^-2))
   (θ₄ - θ₃)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(-0.7713281377661316 - θ₃) - ((θ₃ - θ₂)*(sqrt(θ₁)^-2))
    (θ₅ - θ₄)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.5953344430416955 - θ₄) - ((θ₄ - θ₃)*(sqrt(θ₁)^-2))
    (θ₆ - θ₅)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.4888163369370143 - θ₅) - ((θ₅ - θ₄)*(sqrt(θ₁)^-2))
    (θ₇ - θ₆)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.9869369683122523 - θ₆) - ((θ₆ - θ₅)*(sqrt(θ₁)^-2))
    (θ₈ - θ₇)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(0.18968854568519278 - θ₇) - ((θ₇ - θ₆)*(sqrt(θ₁)^-2))
    (θ₉ - θ₈)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(2.20765228354745 - θ₈) - ((θ₈ - θ₇)*(sqrt(θ₁)^-2))
   (θ₁₀ - θ₉)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(1.1746311120999957 - θ₉) - ((θ₉ - θ₈)*(sqrt(θ₁)^-2))
 (θ₁₁ - θ₁₀)*(sqrt(θ₁)^-2) + (sqrt(θ₁)^-2)*(-1.4166318988896798 - θ₁₀) - ((θ₁₀ - θ₉)*(sqrt(θ₁)^-2))
                              (sqrt(θ₁)^-2)*(1.9786928906469687 - θ₁₁) - ((θ₁₁ - θ₁₀)*(sqrt(θ₁)^-2))

julia> ∂f, ∂f! = Symbolics.build_function(∂lp, θ, expression = false);

julia> ∂f(rand(Symbolics.value(length(θ))))
11-element Vector{Float64}:
 26.085486889785475
 -1.1940138187926816
 -0.709786284651692
 -1.8974997722462343
  3.2298551373276108
  0.07043163767794725
  3.301109716439576
  3.527331147067538
  0.6372253272009718
 -7.676985933953628
  3.717132883407671

Dependencies

julia> Symbolic.dependencies(m) # `VarName` -> `VarName`
Dict{VarName, Vector{T} where T} with 11 entries:
  m[5]  => VarName[m[4], s]
  m[4]  => VarName[m[3], s]
  m[8]  => VarName[m[7], s]
  s     => Union{typeof(var), VarName}[]
  m[6]  => VarName[m[5], s]
  m[3]  => VarName[m[2], s]
  m[1]  => [s]
  m[7]  => VarName[m[6], s]
  m[2]  => VarName[m[1], s]
  m[10] => VarName[m[9], s]
  m[9]  => VarName[m[8], s]

julia> Symbolic.dependencies(m, true) # `Symbol` -> `Symbol`
Dict{Symbolics.Num, Vector{T} where T} with 11 entries:
  θ₉  => SymbolicUtils.Sym{Real, Nothing}[θ₈, θ₁]
  θ₇  => SymbolicUtils.Sym{Real, Nothing}[θ₆, θ₁]
  θ₅  => SymbolicUtils.Sym{Real, Nothing}[θ₄, θ₁]
  θ₆  => SymbolicUtils.Sym{Real, Nothing}[θ₅, θ₁]
  θ₃  => SymbolicUtils.Sym{Real, Nothing}[θ₂, θ₁]
  θ₁₀ => SymbolicUtils.Sym{Real, Nothing}[θ₉, θ₁]
  θ₂  => SymbolicUtils.Sym{Real, Nothing}[θ₁]
  θ₁₁ => SymbolicUtils.Sym{Real, Nothing}[θ₁₀, θ₁]
  θ₁  => Any[]
  θ₄  => SymbolicUtils.Sym{Real, Nothing}[θ₃, θ₁]
  θ₈  => SymbolicUtils.Sym{Real, Nothing}[θ₇, θ₁]

Can of course use LightGraphs.jl, etc. to generate visualizations of the models too:

julia> using LightGraphs, MetaGraphs

julia> function graph(m::Model)
           deps = Symbolic.dependencies(m)

           g = MetaDiGraph(length(deps))
           for (i, vn) in enumerate(keys(deps))
               set_prop!(g, i, :vn, vn)
           end
           set_indexing_prop!(g, :vn)

           for vn in keys(deps)
               for parent in deps[vn]
                   add_edge!(g, (g[parent, :vn], g[vn, :vn]))
               end
           end
           return g
       end
graph (generic function with 1 method)

julia> g = graph(m)
{11, 19} directed Int64 metagraph with Float64 weights defined by :weight (default weight 1.0)

julia> using GraphRecipes, Plots

julia> nodelabels = map(vertices(g)) do v
           get_prop(g, v, :vn)
       end;

julia> graphplot(g, names=nodelabels)

Resulting in:

[Image: Screenshot_20210427_051434, the output of the graphplot call above]

 )
     increment_num_produce!(vi)
-    return Distributions.loglikelihood(dist, value)
+    return sum(Distributions.logpdf(dist, value))

This is left-over from some earlier experimentation I did. We should make loglikelihood a primitive, and then work from there I think.

@torfjelde torfjelde marked this pull request as draft July 24, 2021 15:36
Comment on lines +23 to +41
function symbolize(
    rng::Random.AbstractRNG,
    m::Model,
    vi::VarInfo=VarInfo(m);
    spl=SampleFromPrior(),
    ctx=DefaultContext(),
    include_data=false,
)
    m(rng, vi, spl, ctx)
    θ_orig = vi[spl]

    # Symbolic `logpdf` for fixed observations.
    # TODO: don't `collect` once symbolic arrays are mature enough.
    Symbolics.@variables θ[1:length(θ_orig)]
    vi = VarInfo{Real}(vi, spl, θ)
    m(vi, ctx)

    return vi, θ
end

  1. Execute model once to get shape of latent variables.
  2. Construct symbolic variables.
  3. Execute model on symbolic variables.
  4. vi (the trace struct) now contains a symbolic representation of the logjoint retrievable through getlogp(vi).
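
The four steps above can be illustrated without DynamicPPL or Symbolics.jl. What follows is a minimal sketch under stated assumptions, not how the PR implements symbolize: Sym, toy_normlogpdf and toy_logjoint are hypothetical names made up for this illustration, the prior on s is left out, and a tiny expression-recording type stands in for Symbolics.Num.

```julia
# Hypothetical stand-in for a symbolic number: it only records the expression built so far.
struct Sym
    ex::Any
end

# Overload the few operations the toy model uses so that they record instead of compute.
for op in (:+, :-, :*, :/)
    @eval Base.$op(a::Sym, b::Sym) = Sym(Expr(:call, $(QuoteNode(op)), a.ex, b.ex))
    @eval Base.$op(a::Sym, b::Real) = Sym(Expr(:call, $(QuoteNode(op)), a.ex, b))
    @eval Base.$op(a::Real, b::Sym) = Sym(Expr(:call, $(QuoteNode(op)), a, b.ex))
end
Base.:-(a::Sym) = Sym(Expr(:call, :-, a.ex))
Base.log(a::Sym) = Sym(Expr(:call, :log, a.ex))
Base.abs2(a::Sym) = Sym(Expr(:call, :abs2, a.ex))

# Normal log-density written with generic arithmetic only, so it accepts Float64 or Sym.
toy_normlogpdf(μ, σ, x) = -log(σ) - 0.5 * log(2π) - 0.5 * abs2((x - μ) / σ)

# Toy random-walk log-density: θ[1] is the scale, θ[2:end] are the latent states.
function toy_logjoint(θ, x)
    s, m = θ[1], θ[2:end]
    lp = toy_normlogpdf(0.0, s, m[1]) + toy_normlogpdf(m[1], s, x[1])
    for i in 2:length(x)
        lp = lp + toy_normlogpdf(m[i - 1], s, m[i]) + toy_normlogpdf(m[i], s, x[i])
    end
    return lp
end

x = [0.0, 0.5, 1.0]

# 1. Evaluate once on concrete values (in DynamicPPL this run reveals the shape of the latents).
θ_orig = [1.5, 0.1, 0.2, 0.3]
lp_numeric = toy_logjoint(θ_orig, x)

# 2. Create one symbolic placeholder per latent variable.
θ_sym = [Sym(:(θ[$i])) for i in eachindex(θ_orig)]

# 3. Re-run the very same code on the placeholders.
lp_sym = toy_logjoint(θ_sym, x)

# 4. The result now holds an expression for the log-density; turn it back into a function.
lp_fun = eval(:(θ -> $(lp_sym.ex)))
```

Calling lp_fun(θ_orig) at the REPL should agree with lp_numeric, since both evaluate the same sequence of operations. The sketch only works because toy_logjoint never branches on the values of θ; any such branch would be fixed to whatever the tracing run saw.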

Comment on lines +67 to +72
function dependencies(m::Model, symbolic=false)
    ctx = SymbolicContext(DefaultContext())
    vi = symbolize(m, VarInfo(m); ctx=ctx)

    return dependencies(ctx, symbolic)
end

Uses "contextual" dispatch to overload the corresponding *tilde_ statements, storing the mapping from symbolic variable θ[i] to VarName.
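
The contextual-dispatch idea can be sketched in plain Julia. This is only an illustration under assumptions: ToyDefaultContext, ToyRecordingContext, toy_tilde and toy_model are hypothetical names, and the parents of each variable are passed in by hand, whereas the PR discovers them by tracing symbolic variables.

```julia
abstract type ToyContext end

# Base context: a tilde statement does nothing special here.
struct ToyDefaultContext <: ToyContext end

# Wrapper context: records each variable's parents, then defers to the wrapped context.
struct ToyRecordingContext{C<:ToyContext} <: ToyContext
    inner::C
    deps::Dict{Symbol,Vector{Symbol}}
end
ToyRecordingContext(inner::ToyContext) =
    ToyRecordingContext(inner, Dict{Symbol,Vector{Symbol}}())

toy_tilde(::ToyDefaultContext, name::Symbol, parents) = nothing

function toy_tilde(ctx::ToyRecordingContext, name::Symbol, parents)
    ctx.deps[name] = Symbol[parents...]        # store the mapping for later inspection
    return toy_tilde(ctx.inner, name, parents) # then behave like the wrapped context
end

# The "model" is written once; the context decides what each tilde statement does.
function toy_model(ctx::ToyContext, n::Int)
    toy_tilde(ctx, :s, ())
    toy_tilde(ctx, :m1, (:s,))
    for i in 2:n
        toy_tilde(ctx, Symbol(:m, i), (Symbol(:m, i - 1), :s))
    end
    return ctx
end

ctx = toy_model(ToyRecordingContext(ToyDefaultContext()), 3)
```

After the run, ctx.deps plays the role of the dictionary returned by dependencies: one entry per variable, each listing its parents. Running toy_model with a plain ToyDefaultContext executes the same code without recording anything.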

Comment on lines +47 to +56
function getlogpdf(d, args)
    replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf)

    dsym = Symbol(d)
    if haskey(replacements, dsym)
        return replacements[dsym]
    else
        return d
    end
end

The idea behind all this was to replace the "non-traceable" logpdf impls from Distributions.jl with traceable impls from StatsFuns.jl. After #292 we could probably avoid this by simply using MeasureTheory.jl instead:)

Also, this is super messy and probably not the greatest; I blame it on the fact that I had no idea what I was doing:)
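
The lookup-with-fallback pattern behind getlogpdf can be sketched as follows. Everything here is hypothetical and deliberately independent of Distributions.jl and StatsFuns.jl: toy_normlogpdf is a hand-written Normal log-density that uses only generic arithmetic, and TOY_REPLACEMENTS and toy_getlogpdf are made-up names, not the PR's code.

```julia
# Hand-written Normal log-density: generic arithmetic only, no argument-type restrictions.
toy_normlogpdf(μ, σ, x) = -log(σ) - log(2π) / 2 - abs2((x - μ) / σ) / 2

# Hypothetical table mapping a distribution name to a generic replacement.
const TOY_REPLACEMENTS = Dict{Symbol,Function}(:Normal => toy_normlogpdf)

# Return the registered replacement, or hand the fallback back unchanged.
toy_getlogpdf(name::Symbol, fallback) = get(TOY_REPLACEMENTS, name, fallback)

f = toy_getlogpdf(:Normal, nothing)          # a replacement is registered
g = toy_getlogpdf(:Cauchy, :no_replacement)  # nothing registered, fallback returned
```

Because toy_normlogpdf places no constraints on its argument types, it also accepts BigFloat inputs; the same property is what would let a tracer push symbolic values through it.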

yebai commented Mar 17, 2022

Closed in favour of TuringLang/AbstractPPL.jl#47

@yebai yebai closed this Mar 17, 2022
@yebai yebai deleted the tor/symbolics branch March 17, 2022 17:59