Skip to content

Commit

Permalink
Add workaround for mismatch between eltype and type of samples.
Browse files Browse the repository at this point in the history
The semantics of `eltype` are unresolved, as discussed in this issue: 
JuliaStats/Distributions.jl#1071
  • Loading branch information
ztangent committed Jun 3, 2021
1 parent 48e12e3 commit 52ddc5e
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/distributions.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import SpecialFunctions: logfactorial

safe_eltype(d::Distributions.Distribution) = eltype(d)
safe_eltype(d::ContinuousDistribution) = float(eltype(d))

"Wraps Distributions.jl distributions as Gen.jl distributions."
struct WrappedDistribution{T,D <: Distributions.Distribution} <: Gen.Distribution{T}
dist::D
end
WrappedDistribution(d::D) where {D <: UnivariateDistribution} =
WrappedDistribution{eltype(D),D}(d)
WrappedDistribution{safe_eltype(d),D}(d)
WrappedDistribution(d::D) where {D <: MultivariateDistribution} =
WrappedDistribution{Vector{eltype(D)},D}(d)
WrappedDistribution{Vector{safe_eltype(d)},D}(d)
WrappedDistribution(d::D) where {D <: MatrixDistribution} =
WrappedDistribution{Matrix{eltype(D)},D}(d)
WrappedDistribution{Matrix{safe_eltype(d)},D}(d)
WrappedDistribution(d::Truncated{D}) where {D} =
WrappedDistribution{eltype(D),typeof(d)}(d)
WrappedDistribution{safe_eltype(d),typeof(d)}(d)

(d::WrappedDistribution)(args...) = Gen.random(d, args...)

Expand Down

0 comments on commit 52ddc5e

Please sign in to comment.