Skip to content

Commit 97e49f2

Browse files
authored
importance sampling (#214)
* fix importance sampling * docs * unrelated particles bugfix * update Project.toml * bugfix in Base.show(io::IO, ℓx::Weighted)
1 parent 39f1767 commit 97e49f2

File tree

6 files changed

+55
-21
lines changed

6 files changed

+55
-21
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Soss"
22
uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
33
author = ["Chad Scherrer <chad.scherrer@gmail.com>"]
4-
version = "0.15.3"
4+
version = "0.15.4"
55

66
[deps]
77
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
@@ -20,6 +20,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
2020
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2121
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
2222
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
23+
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
2324
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2425
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
2526
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

src/Soss.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,10 @@ include("inference/rejection.jl")
5757
include("inference/dynamicHMC.jl")
5858
include("inference/advancedhmc.jl")
5959

60-
61-
# include("weighted.jl")
6260
#
6361
# # include("graph.jl")
6462
# # # include("optim.jl")
65-
# include("importance.jl")
63+
include("importance.jl")
6664
#
6765
# # # include("sobols.jl")
6866
# # # include("fromcube.jl")

src/core/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ end
226226

227227
getntkeys(::NamedTuple{A,B}) where {A,B} = A
228228
getntkeys(::Type{NamedTuple{A,B}}) where {A,B} = A
229-
229+
getntkeys(::Type{NamedTuple{A}}) where {A} = A
230230

231231
# These macros quickly define additional methods for when you get tired of typing `NamedTuple()`
232232
macro tuple3args(f)

src/core/weighted.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ end
66

77
using Printf
88
function Base.show(io::IO, ℓx::Weighted)
9-
@printf io "Weighted(%g.4\n" (ℓx.ℓ)
10-
println(",", ℓx.val)
9+
@printf io "Weighted(%.4g" (ℓx.ℓ)
10+
println(io, ", ", ℓx.val)
1111
end

src/importance.jl

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,30 @@
11
using Distributions
22
using MonteCarloMeasurements
33

4+
"""
5+
importanceSample(p(p_args), q(q_args), observed_data)
6+
7+
Sample from `q`, and weight the result to behave as if the sample were taken from `p`. For example,
8+
9+
```
10+
julia> p = @model begin
11+
x ~ Normal()
12+
y ~ Normal(x,1) |> iid(5)
13+
end;
14+
15+
julia> q = @model μ,σ begin
16+
x ~ Normal(μ,σ)
17+
end;
18+
19+
julia> y = rand(p()).y;
20+
21+
julia> importanceSample(p(),q(μ=0.0, σ=0.5), (y=y,))
22+
Weighted(-7.13971.4
23+
,(x = -0.12280566635062592,)
24+
````
25+
"""
26+
function importanceSample end
27+
428
export importanceSample
529
@inline function importanceSample(p::JointDistribution, q::JointDistribution, _data)
630
return _importanceSample(getmoduletypencoding(p.model), p.model, p.args, q.model, q.args, _data)
@@ -9,33 +33,39 @@ end
933
@gg M function _importanceSample(_::Type{M}, p::Model, _pargs, q::Model, _qargs, _data) where M <: TypeLevel{Module}
1034
p = type2model(p)
1135
q = type2model(q)
12-
36+
1337
Expr(:let,
1438
Expr(:(=), :M, from_type(M)),
15-
sourceImportanceSample()(p,q) |> loadvals(_qargs, _data) |> loadvals(_pargs, NamedTuple()))
39+
sourceImportanceSample(_data)(p,q) |> loadvals(_qargs, _data) |> loadvals(_pargs, NamedTuple()) |> merge_pqargs)
40+
41+
1642
end
1743

18-
sourceImportanceSample(p::Model,q::Model) = sourceImportanceSample()(p::Model,q::Model)
44+
sourceImportanceSample(p::Model,q::Model,_data) = sourceImportanceSample(_data)(p::Model,q::Model)
1945

2046
export sourceImportanceSample
21-
function sourceImportanceSample()
47+
function sourceImportanceSample(_data)
2248
function(p::Model,q::Model)
2349
p = canonical(p)
2450
q = canonical(q)
2551
m = merge(p,q)
2652

53+
_datakeys = getntkeys(_data)
54+
2755
function proc(m, st::Sample)
56+
st.x _datakeys && return :(_ℓ += logpdf($(st.rhs), $(st.x)))
57+
2858
if hasproperty(p.dists, st.x)
2959
pdist = getproperty(p.dists, st.x)
30-
qdist = st.rhs
60+
qdist = getproperty(q.dists, st.x)
3161
@gensym ℓx
3262
result = @q begin
3363
$ℓx = importanceSample($pdist, $qdist, _data)
3464
_ℓ += $ℓx.
3565
$(st.x) = $ℓx.val
3666
end
3767
return flatten(result)
38-
else return :($(st.x) = $(st.rhs))
68+
else return :($(st.x) = rand($(st.rhs)))
3969
end
4070
return flatten(result)
4171
end
@@ -45,11 +75,11 @@ function sourceImportanceSample()
4575

4676
body = buildSource(m, proc) |> flatten
4777

48-
kwargs = freeVariables(q) arguments(p)
78+
kwargs = arguments(p) arguments(q)
4979
kwargsExpr = Expr(:tuple,kwargs...)
5080

5181
stochExpr = begin
52-
vals = map(sampled(m)) do x Expr(:(=), x,x) end
82+
vals = map(sampled(q)) do x Expr(:(=), x,x) end
5383
Expr(:tuple, vals...)
5484
end
5585

@@ -64,7 +94,7 @@ function sourceImportanceSample()
6494
end
6595

6696
@inline function importanceSample(p, q, _data)
67-
x = merge(rand(q), _data)
97+
x = rand(q)
6898
= logpdf(p,x) - logpdf(q,x)
6999
Weighted(ℓ,x)
70100
end
@@ -139,3 +169,10 @@ end
139169
# ))
140170

141171
# end
172+
173+
function merge_pqargs(src)
174+
@q begin
175+
_args = merge(_pargs, _qargs)
176+
$src
177+
end |> flatten
178+
end

src/particles.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ end
3333

3434
export parts
3535

36+
using MappedArrays, FillArrays
3637

3738
# Just a little helper function for particles
3839
# https://github.com/baggepinnen/MonteCarloMeasurements.jl/issues/22
@@ -43,12 +44,9 @@ parts(x::Integer, N::Int=DEFAULT_SAMPLE_SIZE) = parts(float(x))
4344
parts(x::Real, N::Int=DEFAULT_SAMPLE_SIZE) = parts(repeat([x],N))
4445
parts(x::AbstractArray, N::Int=DEFAULT_SAMPLE_SIZE) = Particles(x)
4546
parts(p::Particles, N::Int=DEFAULT_SAMPLE_SIZE) = p
46-
parts(d::For, N::Int=DEFAULT_SAMPLE_SIZE) = parts.(d.f.(d.θ...), N)
47-
function parts(d::For{F,Tuple{I}}, N::Int=DEFAULT_SAMPLE_SIZE) where {F <: Function, I <: Integer}
48-
parts.(d.f.(Base.OneTo.(d.θ)...), N)
49-
end
47+
parts(d::For, N::Int=DEFAULT_SAMPLE_SIZE) = parts.(mappedarray(d.f, CartesianIndices(d.θ)), N)
5048

51-
parts(d::iid, N::Int=DEFAULT_SAMPLE_SIZE) = map(1:d.size) do j parts(d.dist, N) end
49+
parts(d::iid, N::Int=DEFAULT_SAMPLE_SIZE) = parts.(fill(d.dist, d.size))
5250
# size
5351
# dist
5452

0 commit comments

Comments
 (0)