11using Distributions
22using MonteCarloMeasurements
33
4+ """
5+ importanceSample(p(p_args), q(q_args), observed_data)
6+
7+ Sample from `q`, and weight the result to behave as if the sample were taken from `p`. For example,
8+
9+ ```
10+ julia> p = @model begin
11+ x ~ Normal()
12+ y ~ Normal(x,1) |> iid(5)
13+ end;
14+
15+ julia> q = @model μ,σ begin
16+ x ~ Normal(μ,σ)
17+ end;
18+
19+ julia> y = rand(p()).y;
20+
21+ julia> importanceSample(p(),q(μ=0.0, σ=0.5), (y=y,))
22+ Weighted(-7.13971.4
23+ ,(x = -0.12280566635062592,)
24+ ````
25+ """
26+ function importanceSample end
27+
428export importanceSample
529@inline function importanceSample (p:: JointDistribution , q:: JointDistribution , _data)
630 return _importanceSample (getmoduletypencoding (p. model), p. model, p. args, q. model, q. args, _data)
933@gg M function _importanceSample (_:: Type{M} , p:: Model , _pargs, q:: Model , _qargs, _data) where M <: TypeLevel{Module}
1034 p = type2model (p)
1135 q = type2model (q)
12-
36+
1337 Expr (:let ,
1438 Expr (:(= ), :M , from_type (M)),
15- sourceImportanceSample ()(p,q) |> loadvals (_qargs, _data) |> loadvals (_pargs, NamedTuple ()))
39+ sourceImportanceSample (_data)(p,q) |> loadvals (_qargs, _data) |> loadvals (_pargs, NamedTuple ()) |> merge_pqargs)
40+
41+
1642end
1743
18- sourceImportanceSample (p:: Model ,q:: Model ) = sourceImportanceSample ()(p:: Model ,q:: Model )
44+ sourceImportanceSample (p:: Model ,q:: Model ,_data ) = sourceImportanceSample (_data )(p:: Model ,q:: Model )
1945
2046export sourceImportanceSample
21- function sourceImportanceSample ()
47+ function sourceImportanceSample (_data )
2248 function (p:: Model ,q:: Model )
2349 p = canonical (p)
2450 q = canonical (q)
2551 m = merge (p,q)
2652
53+ _datakeys = getntkeys (_data)
54+
2755 function proc (m, st:: Sample )
56+ st. x ∈ _datakeys && return :(_ℓ += logpdf ($ (st. rhs), $ (st. x)))
57+
2858 if hasproperty (p. dists, st. x)
2959 pdist = getproperty (p. dists, st. x)
30- qdist = st. rhs
60+ qdist = getproperty (q . dists, st. x)
3161 @gensym ℓx
3262 result = @q begin
3363 $ ℓx = importanceSample ($ pdist, $ qdist, _data)
3464 _ℓ += $ ℓx. ℓ
3565 $ (st. x) = $ ℓx. val
3666 end
3767 return flatten (result)
38- else return :($ (st. x) = $ (st. rhs))
68+ else return :($ (st. x) = rand ( $ (st. rhs) ))
3969 end
4070 return flatten (result)
4171 end
@@ -45,11 +75,11 @@ function sourceImportanceSample()
4575
4676 body = buildSource (m, proc) |> flatten
4777
48- kwargs = freeVariables (q ) ∪ arguments (p )
78+ kwargs = arguments (p ) ∪ arguments (q )
4979 kwargsExpr = Expr (:tuple ,kwargs... )
5080
5181 stochExpr = begin
52- vals = map (sampled (m )) do x Expr (:(= ), x,x) end
82+ vals = map (sampled (q )) do x Expr (:(= ), x,x) end
5383 Expr (:tuple , vals... )
5484 end
5585
@@ -64,7 +94,7 @@ function sourceImportanceSample()
6494end
6595
6696@inline function importanceSample (p, q, _data)
67- x = merge ( rand (q), _data )
97+ x = rand (q)
6898 ℓ = logpdf (p,x) - logpdf (q,x)
6999 Weighted (ℓ,x)
70100end
139169# ))
140170
141171# end
172+
173+ function merge_pqargs (src)
174+ @q begin
175+ _args = merge (_pargs, _qargs)
176+ $ src
177+ end |> flatten
178+ end
0 commit comments