Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Soss"
uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
author = ["Chad Scherrer <chad.scherrer@gmail.com>"]
version = "0.15.3"
version = "0.15.4"

[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Expand All @@ -20,6 +20,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down
4 changes: 1 addition & 3 deletions src/Soss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,10 @@ include("inference/rejection.jl")
include("inference/dynamicHMC.jl")
include("inference/advancedhmc.jl")


# include("weighted.jl")
#
# # include("graph.jl")
# # # include("optim.jl")
# include("importance.jl")
include("importance.jl")
#
# # # include("sobols.jl")
# # # include("fromcube.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/core/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ end

getntkeys(::NamedTuple{A,B}) where {A,B} = A
getntkeys(::Type{NamedTuple{A,B}}) where {A,B} = A

getntkeys(::Type{NamedTuple{A}}) where {A} = A

# These macros quickly define additional methods for when you get tired of typing `NamedTuple()`
macro tuple3args(f)
Expand Down
4 changes: 2 additions & 2 deletions src/core/weighted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ end

using Printf
function Base.show(io::IO, ℓx::Weighted)
@printf io "Weighted(%g.4\n" (ℓx.ℓ)
println(",", ℓx.val)
@printf io "Weighted(%.4g" (ℓx.ℓ)
println(io, ", ", ℓx.val)
end
55 changes: 46 additions & 9 deletions src/importance.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
using Distributions
using MonteCarloMeasurements

"""
importanceSample(p(p_args), q(q_args), observed_data)

Sample from `q`, and weight the result to behave as if the sample were taken from `p`. For example,

```
julia> p = @model begin
x ~ Normal()
y ~ Normal(x,1) |> iid(5)
end;

julia> q = @model μ,σ begin
x ~ Normal(μ,σ)
end;

julia> y = rand(p()).y;

julia> importanceSample(p(),q(μ=0.0, σ=0.5), (y=y,))
Weighted(-7.13971.4
,(x = -0.12280566635062592,)
````
"""
function importanceSample end

export importanceSample
@inline function importanceSample(p::JointDistribution, q::JointDistribution, _data)
return _importanceSample(getmoduletypencoding(p.model), p.model, p.args, q.model, q.args, _data)
Expand All @@ -9,33 +33,39 @@ end
@gg M function _importanceSample(_::Type{M}, p::Model, _pargs, q::Model, _qargs, _data) where M <: TypeLevel{Module}
p = type2model(p)
q = type2model(q)

Expr(:let,
Expr(:(=), :M, from_type(M)),
sourceImportanceSample()(p,q) |> loadvals(_qargs, _data) |> loadvals(_pargs, NamedTuple()))
sourceImportanceSample(_data)(p,q) |> loadvals(_qargs, _data) |> loadvals(_pargs, NamedTuple()) |> merge_pqargs)


end

sourceImportanceSample(p::Model,q::Model) = sourceImportanceSample()(p::Model,q::Model)
sourceImportanceSample(p::Model,q::Model,_data) = sourceImportanceSample(_data)(p::Model,q::Model)

export sourceImportanceSample
function sourceImportanceSample()
function sourceImportanceSample(_data)
function(p::Model,q::Model)
p = canonical(p)
q = canonical(q)
m = merge(p,q)

_datakeys = getntkeys(_data)

function proc(m, st::Sample)
st.x ∈ _datakeys && return :(_ℓ += logpdf($(st.rhs), $(st.x)))

if hasproperty(p.dists, st.x)
pdist = getproperty(p.dists, st.x)
qdist = st.rhs
qdist = getproperty(q.dists, st.x)
@gensym ℓx
result = @q begin
$ℓx = importanceSample($pdist, $qdist, _data)
_ℓ += $ℓx.ℓ
$(st.x) = $ℓx.val
end
return flatten(result)
else return :($(st.x) = $(st.rhs))
else return :($(st.x) = rand($(st.rhs)))
end
return flatten(result)
end
Expand All @@ -45,11 +75,11 @@ function sourceImportanceSample()

body = buildSource(m, proc) |> flatten

kwargs = freeVariables(q) ∪ arguments(p)
kwargs = arguments(p) ∪ arguments(q)
kwargsExpr = Expr(:tuple,kwargs...)

stochExpr = begin
vals = map(sampled(m)) do x Expr(:(=), x,x) end
vals = map(sampled(q)) do x Expr(:(=), x,x) end
Expr(:tuple, vals...)
end

Expand All @@ -64,7 +94,7 @@ function sourceImportanceSample()
end

@inline function importanceSample(p, q, _data)
x = merge(rand(q), _data)
x = rand(q)
ℓ = logpdf(p,x) - logpdf(q,x)
Weighted(ℓ,x)
end
Expand Down Expand Up @@ -139,3 +169,10 @@ end
# ))

# end

function merge_pqargs(src)
@q begin
_args = merge(_pargs, _qargs)
$src
end |> flatten
end
8 changes: 3 additions & 5 deletions src/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ end

export parts

using MappedArrays, FillArrays

# Just a little helper function for particles
# https://github.com/baggepinnen/MonteCarloMeasurements.jl/issues/22
Expand All @@ -43,12 +44,9 @@ parts(x::Integer, N::Int=DEFAULT_SAMPLE_SIZE) = parts(float(x))
parts(x::Real, N::Int=DEFAULT_SAMPLE_SIZE) = parts(repeat([x],N))
parts(x::AbstractArray, N::Int=DEFAULT_SAMPLE_SIZE) = Particles(x)
parts(p::Particles, N::Int=DEFAULT_SAMPLE_SIZE) = p
parts(d::For, N::Int=DEFAULT_SAMPLE_SIZE) = parts.(d.f.(d.θ...), N)
function parts(d::For{F,Tuple{I}}, N::Int=DEFAULT_SAMPLE_SIZE) where {F <: Function, I <: Integer}
parts.(d.f.(Base.OneTo.(d.θ)...), N)
end
parts(d::For, N::Int=DEFAULT_SAMPLE_SIZE) = parts.(mappedarray(d.f, CartesianIndices(d.θ)), N)

parts(d::iid, N::Int=DEFAULT_SAMPLE_SIZE) = map(1:d.size) do j parts(d.dist, N) end
parts(d::iid, N::Int=DEFAULT_SAMPLE_SIZE) = parts.(fill(d.dist, d.size))
# size
# dist

Expand Down