Skip to content

Commit

Permalink
remove SimpleUnPack
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Oct 8, 2024
1 parent 98dfa14 commit 82246ad
Show file tree
Hide file tree
Showing 20 changed files with 31 additions and 45 deletions.
6 changes: 0 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Expand Down Expand Up @@ -54,20 +53,15 @@ ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Zygote = "0.6"
julia = "1.7"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Test"]
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ a `LogDensityProblem` can be implemented as

```julia
using LogDensityProblems
using SimpleUnPack

struct NormalLogNormal{MX,SX,MY,SY}
μ_x::MX
Expand Down
2 changes: 0 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Expand All @@ -25,6 +24,5 @@ LogDensityProblems = "2.1.1"
Optimisers = "0.3"
Plots = "1"
QuasiMonteCarlo = "0.3"
SimpleUnPack = "1"
StatsFuns = "1"
julia = "1.6"
5 changes: 2 additions & 3 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Using the `LogDensityProblems` interface, the model can be defined as follows

```@example elboexample
using LogDensityProblems
using SimpleUnPack
struct NormalLogNormal{MX,SX,MY,SY}
μ_x::MX
Expand All @@ -25,7 +24,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
end
function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end
Expand Down Expand Up @@ -59,7 +58,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat
using Bijectors
function Bijectors.bijector(model::NormalLogNormal)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
return Bijectors.Stacked(
Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
[1:1, 2:(1 + length(μ_y))],
Expand Down
1 change: 0 additions & 1 deletion src/AdvancedVI.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

module AdvancedVI

using SimpleUnPack: @unpack, @pack!
using Accessors

using Random
Expand Down
16 changes: 8 additions & 8 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function (re::RestructureMeanField)(flat::AbstractVector)
end

function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
@unpack location, scale, dist = q
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
end
Expand All @@ -69,27 +69,27 @@ Base.size(q::MvLocationScale) = size(q.location)
Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
# `convert` is necessary because `entropy` is not type stable upstream
return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale)
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
(; location, scale, dist) = q
return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions.rand(q::MvLocationScale)
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
return scale * rand(dist, n_dims) + location
end

function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{S,D,L}, num_samples::Int
) where {S,D,L}
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
return scale * rand(rng, dist, n_dims, num_samples) .+ location
end
Expand All @@ -98,7 +98,7 @@ end
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
@unpack location, scale, dist = q
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
Expand All @@ -107,14 +107,14 @@ end
function Distributions._rand!(
rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}
)
@unpack location, scale, dist = q
(; location, scale, dist) = q
rand!(rng, dist, x)
x[:] = scale * x
return x .+= location
end

function Distributions.mean(q::MvLocationScale)
@unpack location, scale = q
(; location, scale) = q
return location + scale * Fill(mean(q.dist), length(location))
end

Expand Down
16 changes: 8 additions & 8 deletions src/families/location_scale_low_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Base.size(q::MvLocationScaleLowRank) = size(q.location)
Base.eltype(::Type{<:MvLocationScaleLowRank{L,SD,SF,D,E}}) where {L,SD,SF,D,E} = eltype(L)

function StatsBase.entropy(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
n_dims = length(location)
scale_diag2 = scale_diag .* scale_diag
UtDinvU = Hermitian(scale_factors' * (scale_factors ./ scale_diag2))
Expand All @@ -63,7 +63,7 @@ end
function Distributions.logpdf(
q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable::Bool=false
)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
μ_base = mean(dist)
n_dims = length(location)

Expand All @@ -86,7 +86,7 @@ function Distributions.logpdf(
end

function Distributions.rand(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
n_dims = length(location)
n_factors = size(scale_factors, 2)
u_diag = rand(dist, n_dims)
Expand All @@ -97,7 +97,7 @@ end
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScaleLowRank{S,D,L}, num_samples::Int
) where {S,D,L}
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q
n_dims = length(location)
n_factors = size(scale_factors, 2)
u_diag = rand(rng, dist, n_dims, num_samples)
Expand All @@ -108,7 +108,7 @@ end
function Distributions._rand!(
rng::AbstractRNG, q::MvLocationScaleLowRank, x::AbstractVecOrMat{<:Real}
)
@unpack location, scale_diag, scale_factors, dist = q
(; location, scale_diag, scale_factors, dist) = q

rand!(rng, dist, x)
x[:] = scale_diag .* x
Expand All @@ -120,22 +120,22 @@ function Distributions._rand!(
end

function Distributions.mean(q::MvLocationScaleLowRank)
@unpack location, scale_diag, scale_factors = q
(; location, scale_diag, scale_factors) = q
μ = mean(q.dist)
return location +
scale_diag .* Fill(μ, length(scale_diag)) +
scale_factors * Fill(μ, size(scale_factors, 2))
end

function Distributions.var(q::MvLocationScaleLowRank)
@unpack scale_diag, scale_factors = q
(; scale_diag, scale_factors) = q
σ2 = var(q.dist)
return σ2 *
(scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1])
end

function Distributions.cov(q::MvLocationScaleLowRank)
@unpack scale_diag, scale_factors = q
(; scale_diag, scale_factors) = q
σ2 = var(q.dist)
return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors')
end
Expand Down
2 changes: 1 addition & 1 deletion src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp
end

function estimate_repgradelbo_ad_forward(params′, aux)
@unpack rng, obj, problem, adtype, restructure, q_stop = aux
(; rng, obj, problem, adtype, restructure, q_stop) = aux
q = restructure_ad_forward(adtype, restructure, params′)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(problem, samples)
Expand Down
2 changes: 1 addition & 1 deletion src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_sa
end

function estimate_scoregradelbo_ad_forward(params′, aux)
@unpack rng, obj, problem, adtype, restructure, q_stop = aux
(; rng, obj, problem, adtype, restructure, q_stop) = aux
baseline = compute_control_variate_baseline(
obj.baseline_history, obj.baseline_window_size
)
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -39,7 +38,6 @@ Optimisers = "0.2.16, 0.3"
PDMats = "0.11.7"
Random = "1"
ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StableRNGs = "1.0.0"
Statistics = "1"
StatsBase = "0.34"
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-3
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-3
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
Expand Down
4 changes: 2 additions & 2 deletions test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Test

modelstats = normal_meanfield(rng, Float64)

@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))

Expand All @@ -32,7 +32,7 @@ end
rng = StableRNG(seed)

modelstats = normal_meanfield(rng, Float64)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

ad_backends = [
ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote()
Expand Down
2 changes: 1 addition & 1 deletion test/interface/scoregradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Test

modelstats = normal_meanfield(rng, Float64)

@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))

Expand Down
2 changes: 1 addition & 1 deletion test/models/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct TestNormal{M,S}
end

function LogDensityProblems.logdensity(model::TestNormal, θ)
@unpack μ, Σ = model
(; μ, Σ) = model
return logpdf(MvNormal(μ, Σ), θ)
end

Expand Down
4 changes: 2 additions & 2 deletions test/models/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
end

function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

Expand All @@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
end

function Bijectors.bijector(model::NormalLogNormal)
@unpack μ_x, σ_x, μ_y, Σ_y = model
(; μ_x, σ_x, μ_y, Σ_y) = model
return Bijectors.Stacked(
Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
[1:1, 2:(1 + length(μ_y))],
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using Optimisers
using PDMats
using Pkg
using Random, StableRNGs
using SimpleUnPack: @unpack
using Statistics
using StatsBase

Expand Down

0 comments on commit 82246ad

Please sign in to comment.