Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NamedTupleVariate and ProductNamedTupleDistribution #1803

Open
wants to merge 41 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
142380b
Add NamedTupleVariate
sethaxen Nov 23, 2023
191ca1a
Add ProductNamedTupleDistribution
sethaxen Nov 23, 2023
399b03b
Correctly implement eltype
sethaxen Nov 23, 2023
eb946a8
Simplify insupport implementation
sethaxen Nov 23, 2023
32ca2f0
Overload std for ProductNamedTupleDistribution
sethaxen Nov 23, 2023
a416f02
Simplify rand for ProductNamedTupleDistribution
sethaxen Nov 23, 2023
7deff94
Reformat line
sethaxen Nov 23, 2023
978b2de
Add docstring to ProductNamedTupleDistribution
sethaxen Nov 23, 2023
b718e59
Add marginal API function
sethaxen Nov 23, 2023
d08431d
Add marginal for ProductDistribution
sethaxen Nov 23, 2023
79e5d59
Rearrange marginal
sethaxen Nov 23, 2023
52fb9a0
Allow tuple indexing via marginal
sethaxen Nov 23, 2023
1509abd
Make logpdf type-stable
sethaxen Nov 23, 2023
450fb7d
Add loglikelihood
sethaxen Nov 23, 2023
eb2ed6c
Support extrema for multivariate distributions
sethaxen Nov 23, 2023
e3a0814
Add tests
sethaxen Nov 23, 2023
9acc869
Improve type-inferrability
sethaxen Nov 23, 2023
6d8df2a
Remove extension
sethaxen Nov 23, 2023
d115441
Merge branch 'master' into namedtuplevariate
sethaxen May 27, 2024
9f19a2e
Merge branch 'master' into namedtuplevariate
devmotion Jul 14, 2024
800de5b
Apply suggestions from code review
sethaxen Jul 15, 2024
0b83587
Remove marginal
sethaxen Jul 15, 2024
ba03eea
Add sampler for product namedtuple
sethaxen Jul 15, 2024
1712be6
Use ProductNamedTupleSampler for array rand calls
sethaxen Jul 15, 2024
1056d0d
Add docs page for product distributions
sethaxen Aug 19, 2024
58937fd
Fix typo
sethaxen Aug 19, 2024
d7fd842
Fix ProductNamedTuple docstring
sethaxen Aug 19, 2024
c8b1602
Add deprecation warning to Product docstring
sethaxen Aug 19, 2024
db029c5
Move multivariate product distributions to own page
sethaxen Aug 19, 2024
2634adb
Document NamedTuple products
sethaxen Aug 19, 2024
eb5b176
Add docs index
sethaxen Aug 19, 2024
3ebc3ba
Document usage of ProductNamedTuple
sethaxen Aug 19, 2024
f0dd8c4
Load Distributions for jldoctest
sethaxen Aug 19, 2024
121dd2b
Apply suggestions from code review
sethaxen Sep 4, 2024
a86cac4
Call method on NamedTuple
sethaxen Sep 4, 2024
46fdcfc
Revert to typejoin based eltype
sethaxen Sep 5, 2024
54b0d03
Explicitly check eltype of dist matches that of draw
sethaxen Sep 5, 2024
fe284b1
Correctly compute eltype for nested prod namedtuple distributions
sethaxen Sep 5, 2024
96ccc99
Merge branch 'master' into namedtuplevariate
sethaxen Sep 5, 2024
1eabd23
Revert "Call method on NamedTuple"
sethaxen Sep 5, 2024
28a7c00
Update test/namedtuple/productnamedtuple.jl
sethaxen Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export
Multivariate,
Matrixvariate,
CholeskyVariate,
NamedTupleVariate,
Discrete,
Continuous,
Sampleable,
Expand Down Expand Up @@ -230,6 +231,7 @@ export
sqmahal!, # in-place evaluation of sqmahal
location, # get the location parameter
location!, # provide storage for the location parameter (used in multivariate distribution mvlognormal)
marginal, # marginal distributions
mean, # mean of distribution
meandir, # mean direction (of a spherical distribution)
meanform, # convert a normal distribution from canonical form to mean form
Expand Down Expand Up @@ -297,6 +299,7 @@ include("univariates.jl")
include("edgeworth.jl")
include("multivariates.jl")
include("matrixvariates.jl")
include("namedtuple/productnamedtuple.jl")
include("cholesky/lkjcholesky.jl")
include("samplers.jl")

Expand Down
16 changes: 16 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ const Univariate = ArrayLikeVariate{0}
const Multivariate = ArrayLikeVariate{1}
const Matrixvariate = ArrayLikeVariate{2}

"""
`F <: NamedTupleVariate{K}` specifies that the variate or a sample is of type
`NamedTuple{K}`.
"""
abstract type NamedTupleVariate{K} <: VariateForm end
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

"""
`F <: CholeskyVariate` specifies that the variate or a sample is of type
`LinearAlgebra.Cholesky`.
Expand Down Expand Up @@ -464,6 +470,16 @@ Base.@propagate_inbounds function loglikelihood(
return sum(Base.Fix1(logpdf, d), x)
end

"""
marginal(d::Distribution, k...) -> Distribution

Return the marginal distribution of `d` at the indices `k...`.

The result is the distribution of the variate `rand(d)[k...]` that one would obtain by
integrating over all other indices.
"""
marginal(d::Distribution, k...)

sethaxen marked this conversation as resolved.
Show resolved Hide resolved
## TODO: the following types need to be improved
abstract type SufficientStats end
abstract type IncompleteDistribution end
Expand Down
131 changes: 131 additions & 0 deletions src/namedtuple/productnamedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
Distribution{NamedTupleVariate{Tnames},S}

A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named
distributions.

Users should use [`product_distribution`](@ref) to construct a product distribution of
independent distributions instead of constructing a `ProductNamedTupleDistribution`
directly.
"""
struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
Distribution{NamedTupleVariate{Tnames},S}
dists::NamedTuple{Tnames,Tdists}
end
# Outer constructor: compute the value-support and eltype type parameters from the
# component distributions, then wrap `dists` unchanged.
function ProductNamedTupleDistribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Vararg{Distribution}}}
    components = values(dists)
    support = _product_valuesupport(components)
    T = _product_namedtuple_eltype(components)
    return ProductNamedTupleDistribution{K,V,support,T}(dists)
end

# `_gentype(d)` maps a component distribution to the type that is fed into `typejoin`
# when the eltype of the product distribution is computed.

# Univariate: the distribution's own eltype.
function _gentype(d::UnivariateDistribution)
    return eltype(d)
end

# Array-like variates of dimension `N`: an `N`-dimensional `Array` of the eltype.
function _gentype(d::Distribution{<:ArrayLikeVariate{N}}) where {N}
    return Array{eltype(d),N}
end

# Cholesky variates: a `Cholesky` factorization backed by a `Matrix` of the eltype.
function _gentype(d::Distribution{CholeskyVariate})
    E = eltype(d)
    return LinearAlgebra.Cholesky{E,Matrix{E}}
end

# Anything else: no more specific type is assumed.
function _gentype(::Distribution)
    return Any
end

# Join the per-component `_gentype`s into one common supertype.
function _product_namedtuple_eltype(dists)
    return typejoin(map(_gentype, dists)...)
end
devmotion marked this conversation as resolved.
Show resolved Hide resolved

# Print the distribution via `show_multline`, passing the components as a vector of
# `name => distribution` pairs.
function Base.show(io::IO, d::ProductNamedTupleDistribution)
    named_dists = collect(pairs(d.dists))
    return show_multline(io, d, named_dists)
end

# Name used when printing: the type name followed by the field names only.
distrname(::ProductNamedTupleDistribution{K}) where {K} =
    "ProductNamedTupleDistribution{$K}"

"""
product_distribution(dists::Namedtuple{K,Tuple{Vararg{Distribution}}}) where {K}

Create a distribution of `NamedTuple`s as a product distribution of independent named
distributions.

The function falls back to constructing a [`ProductNamedTupleDistribution`](@ref)
distribution but specialized methods can be defined.
"""
function product_distribution(
dists::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}}
)
return ProductNamedTupleDistribution(dists)
end

# Properties

# The eltype is stored as the fourth type parameter, so it is read off the type.
function Base.eltype(
    ::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}
) where {T}
    return T
end

# Component-wise lower bounds, returned as a `NamedTuple` with the same names.
function Base.minimum(d::ProductNamedTupleDistribution)
    return map(minimum, d.dists)
end

# Component-wise upper bounds, returned as a `NamedTuple` with the same names.
function Base.maximum(d::ProductNamedTupleDistribution)
    return map(maximum, d.dists)
end

# A single component, selected by position or by name.
function marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol})
    return d.dists[k]
end

# Several components selected by a non-empty tuple of names; the selection is wrapped
# in a new product distribution. Only defined on Julia versions at or after the guard.
if VERSION >= v"1.7.0-DEV.294"
    function marginal(
        d::ProductNamedTupleDistribution, ks::Tuple{Symbol,Vararg{Symbol}}
    )
        selected = d.dists[ks]
        return ProductNamedTupleDistribution(selected)
    end
end

# `x` is in the support when every field is in the support of its component.
# The shared parameter `K` requires `x` to have the same field names in the same order.
function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    memberships = map(insupport, dist.dists, x)
    return all(memberships)
end

# Evaluation

# The density is obtained by exponentiating the log-density.
function pdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    logp = logpdf(dist, x)
    return exp(logp)
end

# The joint log-density is the sum of the component log-densities.
function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    component_logpdfs = map(logpdf, dist.dists, x)
    return sum(component_logpdfs)
end

# For a single draw the log-likelihood is the log-density.
function loglikelihood(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    return logpdf(dist, x)
end

# For an array of draws the log-likelihood is the sum over the individual draws.
function loglikelihood(
    dist::ProductNamedTupleDistribution{K}, xs::AbstractArray{<:NamedTuple{K}}
) where {K}
    return sum(xs) do x
        loglikelihood(dist, x)
    end
end

# Statistics

# Component-wise summaries keep the field names of `d.dists`.

function mode(d::ProductNamedTupleDistribution)
    return map(mode, d.dists)
end

function mean(d::ProductNamedTupleDistribution)
    return map(mean, d.dists)
end

function var(d::ProductNamedTupleDistribution)
    return map(var, d.dists)
end

function std(d::ProductNamedTupleDistribution)
    return map(std, d.dists)
end

# The entropy is a single number: the sum of the component entropies.
function entropy(d::ProductNamedTupleDistribution)
    return sum(entropy, values(d.dists))
end

# The KL divergence between two products with the same field names is the sum of the
# divergences between components paired in order.
function kldivergence(
    d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K}
) where {K}
    return mapreduce(+, d1.dists, d2.dists) do p, q
        kldivergence(p, q)
    end
end

# Sampling

# A single draw: sample every component with `rng`, in field order, and return the
# values as a `NamedTuple` with names `K`.
function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K}
    draws = map(dist -> rand(rng, dist), d.dists)
    return NamedTuple{K}(draws)
end
# An array of draws with size `dims`: each element is an independent `rand(rng, d)`.
#
# The previous body read `xs = return map(...)`, which returned before the assignment
# could happen and left an unreachable `return xs`; the result is now returned directly.
function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution, dims::Dims)
    return map(CartesianIndices(dims)) do _
        return rand(rng, d)
    end
end

# In-place sampling: overwrite every element of `xs` with a fresh draw from `d` and
# return `xs`.
function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray)
    for idx in eachindex(xs)
        draw = Random.rand(rng, d)
        xs[idx] = draw
    end
    return xs
end
2 changes: 2 additions & 0 deletions src/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ minimum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(minimum, d.dis
# Element-wise maxima of the component distributions.
function maximum(d::ArrayOfUnivariateDistribution)
    return map(maximum, d.dists)
end

# When the components are stored in a tuple, collect the maxima into a vector.
function maximum(d::VectorOfUnivariateDistribution{<:Tuple})
    upper_bounds = map(maximum, d.dists)
    return collect(upper_bounds)
end

# The marginal at indices `i...` is the component stored at those indices.
function marginal(d::ProductDistribution, i...)
    return d.dists[i...]
end

function entropy(d::ArrayOfUnivariateDistribution)
# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
return sum(Broadcast.instantiate(Broadcast.broadcasted(entropy, d.dists)))
Expand Down
198 changes: 198 additions & 0 deletions test/namedtuple/productnamedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
using Distributions
using Distributions: ProductNamedTupleDistribution
using LinearAlgebra
using Random
using Test

@testset "ProductNamedTupleDistribution" begin
@testset "Constructor" begin
    # Two univariate continuous components.
    normals = (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0))
    d_normals = @inferred ProductNamedTupleDistribution(normals)
    @test d_normals isa ProductNamedTupleDistribution
    @test d_normals.dists === normals
    @test Distributions.variate_form(typeof(d_normals)) === NamedTupleVariate{(:x, :y)}
    @test Distributions.value_support(typeof(d_normals)) === Continuous

    # Components with different variate forms and value supports.
    mixed = (
        x=Normal(),
        y=Dirichlet(10, 1.0),
        z=DiscreteUniform(1, 10),
        w=LKJCholesky(3, 2.0),
    )
    d_mixed = @inferred ProductNamedTupleDistribution(mixed)
    @test d_mixed isa ProductNamedTupleDistribution
    @test d_mixed.dists === mixed
    @test Distributions.variate_form(typeof(d_mixed)) ===
        NamedTupleVariate{(:x, :y, :z, :w)}
    @test Distributions.value_support(typeof(d_mixed)) === Continuous
end

@testset "product_distribution" begin
    # `product_distribution` must agree with the direct constructor.
    normals = (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0))
    d_normals = @inferred product_distribution(normals)
    @test d_normals === ProductNamedTupleDistribution(normals)

    mixed = (
        x=Normal(),
        y=Dirichlet(10, 1.0),
        z=DiscreteUniform(1, 10),
        w=LKJCholesky(3, 2.0),
    )
    d_mixed = @inferred product_distribution(mixed)
    @test d_mixed === ProductNamedTupleDistribution(mixed)
end

# Checks the multi-line `show` output: a header with the field names followed by one
# `name: distribution` line per component.
# NOTE(review): the two lines right after the opening triple quote below look like
# review-UI residue rather than expected output, and the original indentation of the
# expected string is not visible here — confirm against the real test file.
@testset "show" begin
d = ProductNamedTupleDistribution((x=Gamma(1.0, 2.0), y=Normal()))
@test sprint(show, d) == """
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
ProductNamedTupleDistribution{(:x, :y)}(
x: Gamma{Float64}(α=1.0, θ=2.0)
y: Normal{Float64}(μ=0.0, σ=1.0)
)
"""
end

@testset "Properties" begin
    @testset "eltype" begin
        # Each case pairs a set of components with the eltype expected for the product.
        cases = (
            (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0)) => Float64,
            (x=Normal(), y=Gamma()) => Float64,
            (x=Bernoulli(),) => Bool,
            (x=Normal(), y=Bernoulli()) => Real,
            (w=LKJCholesky(3, 2.0),) =>
                LinearAlgebra.Cholesky{Float64,Array{Float64,2}},
            (
                x=Normal(),
                y=Dirichlet(10, 1.0),
                z=DiscreteUniform(1, 10),
                w=LKJCholesky(3, 2.0),
            ) => Any,
        )
        for (components, expected) in cases
            d = ProductNamedTupleDistribution(components)
            @test eltype(d) === expected
        end
    end

    @testset "minimum" begin
        components = (x=Normal(1.0, 2.0), y=Gamma(), z=MvNormal(ones(5)))
        d = ProductNamedTupleDistribution(components)
        expected = (
            x=minimum(components.x), y=minimum(components.y), z=minimum(components.z)
        )
        @test @inferred(minimum(d)) == expected
    end

    @testset "maximum" begin
        components = (x=Normal(1.0, 2.0), y=Gamma(), z=MvNormal(ones(5)))
        d = ProductNamedTupleDistribution(components)
        expected = (
            x=maximum(components.x), y=maximum(components.y), z=maximum(components.z)
        )
        @test @inferred(maximum(d)) == expected
    end

    @testset "marginal" begin
        components = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0))
        d = ProductNamedTupleDistribution(components)
        # Lookup by name, then by position.
        for name in (:x, :y, :z)
            @test marginal(d, name) === components[name]
        end
        for position in 1:3
            @test marginal(d, position) === components[position]
        end
        if VERSION >= v"1.7.0-DEV.294"
            @test marginal(d, (:x, :y)) ===
                ProductNamedTupleDistribution((x=components[:x], y=components[:y]))
            @test marginal(d, (:z, :x)) ===
                ProductNamedTupleDistribution((z=components[:z], x=components[:x]))
            @test_throws ErrorException marginal(d, (:x, :w))
        end
        @test_throws MethodError marginal(d, ())
    end

    @testset "insupport" begin
        components = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0))
        d = ProductNamedTupleDistribution(components)
        draw = (x=rand(components.x), y=rand(components.y), z=rand(components.z))
        @test @inferred(insupport(d, draw))
        # Reordered or missing field names do not match the method signature.
        @test_throws MethodError insupport(d, NamedTuple{(:y, :z, :x)}(draw))
        @test_throws MethodError insupport(d, NamedTuple{(:x, :y)}(draw))
        # One out-of-support field is enough to reject the whole draw.
        @test !insupport(d, merge(draw, (x=NaN,)))
        @test !insupport(d, merge(draw, (y=-1,)))
        @test !insupport(d, merge(draw, (z=fill(0.25, 4),)))
    end
end

@testset "Evaluation" begin
    components = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli())
    d = ProductNamedTupleDistribution(components)
    draw = (
        x=rand(components.x),
        y=rand(components.y),
        z=rand(components.z),
        w=rand(components.w),
    )
    # The joint log-density is the sum of the component log-densities, in field order.
    @test @inferred(logpdf(d, draw)) ==
        logpdf(components.x, draw.x) +
          logpdf(components.y, draw.y) +
          logpdf(components.z, draw.z) +
          logpdf(components.w, draw.w)
    @test @inferred(pdf(d, draw)) == exp(logpdf(d, draw))
    @test @inferred(loglikelihood(d, draw)) == logpdf(d, draw)
    draws = [
        (
            x=rand(components.x),
            y=rand(components.y),
            z=rand(components.z),
            w=rand(components.w),
        ) for _ in 1:10
    ]
    @test @inferred(loglikelihood(d, draws)) == sum(logpdf.(Ref(d), draws))
end

@testset "Statistics" begin
    components = (x=Normal(1.0, 2.0), y=Gamma(), z=MvNormal(1.0:5.0), w=Poisson(100))
    d = ProductNamedTupleDistribution(components)
    # Component-wise statistics keep the field names.
    @test @inferred(mode(d)) == (
        x=mode(components.x),
        y=mode(components.y),
        z=mode(components.z),
        w=mode(components.w),
    )
    @test @inferred(mean(d)) == (
        x=mean(components.x),
        y=mean(components.y),
        z=mean(components.z),
        w=mean(components.w),
    )
    @test @inferred(var(d)) == (
        x=var(components.x),
        y=var(components.y),
        z=var(components.z),
        w=var(components.w),
    )
    # Entropy collapses to a single sum over the components.
    @test @inferred(entropy(d)) ==
        entropy(components.x) +
          entropy(components.y) +
          entropy(components.z) +
          entropy(components.w)

    p = ProductNamedTupleDistribution((x=Normal(1.0, 2.0), y=Gamma()))
    q = ProductNamedTupleDistribution((x=Normal(), y=Gamma(2.0, 3.0)))
    @test kldivergence(p, q) ==
        kldivergence(p.dists.x, q.dists.x) + kldivergence(p.dists.y, q.dists.y)

    r = ProductNamedTupleDistribution((x=Normal(1.0, 2.0), y=Gamma(6.0, 7.0)))
    @test std(r) == (x=std(r.dists.x), y=std(r.dists.y))
end

@testset "Sampling" begin
    rng = MersenneTwister(973)

    @testset "rand" begin
        components = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli())
        d = ProductNamedTupleDistribution(components)
        rng = MersenneTwister(973)
        first_draw = @inferred rand(rng, d)
        @test eltype(first_draw) === eltype(d)
        # Re-seed and sample each component by hand, in field order, to reproduce the
        # draw obtained from the product distribution.
        rng = MersenneTwister(973)
        manual_draw = (
            x=rand(rng, components.x),
            y=rand(rng, components.y),
            z=rand(rng, components.z),
            w=rand(rng, components.w),
        )
        @test first_draw == manual_draw
        next_draw = rand(rng, d)
        @test next_draw != first_draw

        vector_draws = @inferred rand(rng, d, 10)
        @test length(vector_draws) == 10
        @test all(insupport.(Ref(d), vector_draws))

        array_draws = @inferred rand(rng, d, (2, 3, 4))
        @test size(array_draws) == (2, 3, 4)
        @test all(insupport.(Ref(d), array_draws))
    end

    @testset "rand!" begin
        d = ProductNamedTupleDistribution((
            x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli()
        ))
        # A sample draw fixes the element type of the buffer that is filled in place.
        template = rand(d)
        buffer = Array{typeof(template)}(undef, (2, 3, 4))
        rand!(d, buffer)
        @test all(insupport.(Ref(d), buffer))
    end
end
end
Loading
Loading