Following from a discussion on Slack:
One thing which I think we need to solve is that there is a key architectural difference between Soss/Gen and Turing.
In the former case, you can construct a model in the modeling syntax without any reference to the inference algorithm you will eventually run, or to the target variables you will perform inference on in that model.
E.g. when constructing a GenerativeFunction in Gen, the important thing is the insertion of traceat expressions, which are configured after the fact with the state passed in by each GFI method. But the macro expansion doesn't require you to say which algorithms you're going to use on the model, or which variables you are curious about.
In Turing (https://turing.ml/dev/docs/using-turing/advanced), you have to fix up front the thing you are performing inference on, and you also have to pass in the yet-to-be-determined sampler as state.
Now, there's an interesting solution to this - but it's a bit invasive. I can write a dynamo which intercepts these calls, completely ignores their arguments, and grabs the log probability after sampling from the model with a context that does single-step sampling. I.e. you hollow out the Turing model so that you can connect it up to the GFI. This is basically a Turing model with the connection to AbstractMCMC cut out, and the notion of a target variable removed.
Then a user can re-introduce the interface to AbstractMCMC through the framework which is doing the hollowing. And target variables can be queried selectively in that framework.
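Roughly, the hollowing machinery could look something like this (just a sketch to make the idea concrete - Hollow, its score field, and the usage at the bottom are placeholder names, not existing code; the intercepting method overloads are sketched further down in this issue, and mf is the model function from the example below):

using IRTools: @dynamo, IR, recurse!

# Hypothetical context which carries the accumulated log probability,
# taking over that role from the VarInfo.
mutable struct Hollow
    score::Float64
end
Hollow() = Hollow(0.0)

# The dynamo: rewrite every call f(args...) inside the model function into
# h(f, args...), so that the Turing.Inference primitives can be intercepted
# by ordinary method overloads on Hollow; everything else just recurses.
@dynamo function (h::Hollow)(args...)
    ir = IR(args...)
    ir === nothing && return   # no IR available (intrinsics etc.): plain call
    recurse!(ir)
    return ir
end

# Usage sketch: one pass over the hollowed model function, after which
# h.score holds the joint log probability - whatever vi/sampler/ctx were.
#   h = Hollow()
#   h(mf, vi, sampler, ctx, model)
#   h.score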
My idea here is completely based on this example, which shows the primitives which are inserted during the @model expansion:
using Turing
# Initialize a NamedTuple containing our data variables.
data = (x = [1.5, 2.0],)
# Create the model function.
mf(vi, sampler, ctx, model) = begin
    # Set the accumulated logp to zero.
    resetlogp!(vi)
    x = model.args.x

    # Assume s has an InverseGamma distribution.
    s, lp = Turing.Inference.tilde(
        ctx,
        sampler,
        InverseGamma(2, 3),
        Turing.@varname(s),
        (),
        vi,
    )

    # Add the lp to the accumulated logp.
    acclogp!(vi, lp)

    # Assume m has a Normal distribution.
    m, lp = Turing.Inference.tilde(
        ctx,
        sampler,
        Normal(0, sqrt(s)),
        Turing.@varname(m),
        (),
        vi,
    )

    # Add the lp to the accumulated logp.
    acclogp!(vi, lp)

    # Observe each value of x[i], according to a
    # Normal distribution.
    lp = Turing.Inference.dot_tilde(ctx, sampler, Normal(m, sqrt(s)), x, vi)
    acclogp!(vi, lp)
end

# Instantiate a Model object.
model = DynamicPPL.Model(mf, data, DynamicPPL.ModelGen{()}(nothing, nothing))

# Sample the model.
chain = sample(model, HMC(0.1, 5), 1000)

I would like to be able to call Turing from my GFI/dynamo-based system like so:
using Jaynes
Jaynes.@load_gen_fmi()
Jaynes.@load_soss_fmi()
Jaynes.@load_turing_fmi()
# A Soss model.
m = Soss.@model σ begin
    μ ~ Normal()
    y ~ Normal(μ, σ) |> iid(5)
end

# A Turing model.
Turing.@model gdemo(x) = begin
    # Set priors.
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    # Observe each value of x.
    @. x ~ Normal(m, sqrt(s))
end

# A Gen model.
@gen (static) function foo(x::Float64)
    y = @trace(normal(x, 1.0), :y)
    return y
end
Gen.load_generated_functions()

bar = () -> begin
    x = rand(:x, Normal(5.0, 1.0))
    gen_ret = foreign(:gen, foo, x)
    soss_ret = foreign(:soss, m, (σ = x,))
    turing_ret = foreign(:tur, gdemo)
    return turing_ret
end

ret, cl = Jaynes.simulate(bar)
display(cl.trace)

This would provide a minimal set of interfaces for true plug-and-play in the ecosystem.
I plan to do this by writing a set of GFI language extensions, as shown here:
https://github.com/femtomc/Jaynes.jl/blob/master/src/foreign_model_interfaces/gen.jl
https://github.com/femtomc/Jaynes.jl/blob/master/src/foreign_model_interfaces/soss.jl
Turing will be a bit more complicated than these other interfaces, but I believe it can be done. Specifically, we need to determine what a call like foreign(:tur, gdemo) means with respect to the GFI.
My initial idea was to extract the model function out of the model, create a bunch of dummy types which match the signature of the model function call, and then context-intercept each of the Turing.Inference calls in the model function, e.g.:
(mx::ExecutionContext)(inf_fn::typeof(Turing.Inference.dot_tilde), args...)
(mx::ExecutionContext)(inf_fn::typeof(acclogp!), vi, lp)
(mx::ExecutionContext)(inf_fn::typeof(Turing.Inference.tilde), ctx, sampler, dist, varname, foo, vi)

Now, we give each of these function calls a semantics under the GFI. In particular, we accumulate the logpdf onto the context (and not the vi object), we ignore the target variable inference spec (as specified by dot_tilde), and we don't use the sampler.
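To make those semantics concrete, here is a minimal sketch of what the overloads might look like (again, only a sketch: the score field and record! helper on the context are placeholders, not an existing Jaynes or DynamicPPL API; the signatures mirror the primitive calls from the example at the top, and the Turing/Jaynes setup from the snippets above is assumed):

# ~ statements: sample from the prior, record the choice at the variable's
# address, and accumulate the logpdf on the context rather than on vi.
function (mx::ExecutionContext)(::typeof(Turing.Inference.tilde),
                                ctx, sampler, dist, varname, inds, vi)
    val = rand(dist)
    mx.score += logpdf(dist, val)   # accumulate on the context, not the vi object
    record!(mx, varname, val)       # hypothetical: stash the choice in the trace
    return val, 0.0                 # zero lp, so the model's acclogp!(vi, lp) adds nothing
end

# Broadcasted ~ statements (the observations in the example above): score the
# values under the distribution and accumulate on the context.
function (mx::ExecutionContext)(::typeof(Turing.Inference.dot_tilde),
                                ctx, sampler, dist, xs, vi)
    lp = sum(logpdf(dist, x) for x in xs)
    mx.score += lp
    return 0.0
end

# The VarInfo no longer carries the score, so acclogp! becomes a no-op.
(mx::ExecutionContext)(::typeof(acclogp!), vi, lp) = vi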
What this basically means is that you cut a lot of the Turing interfaces away, and instead use it as a DSL which ultimately has a different semantics under the GFI, one which is closer to the trace-based semantics. That comes at a cost of course - there's a lot of usefulness in the interfaces, and especially the samplers. I have the beginnings of an idea about bridging AbstractMCMC to a trace-based framework - but nothing is going to be as elegant as doing inference using the host interfaces in Turing itself.
Other reference issues:
femtomc/Jaynes.jl#20
https://github.com/TuringLang/Turing.jl/issues/1351