Skip to content

Unify argument order in phasepoint and transition #435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 1, 2025
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# AdvancedHMC Changelog

## 0.8.0

- To make an MCMC transtion from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).

## v0.7.1

- README has been simplified, many docs transfered to docs: https://turinglang.org/AdvancedHMC.jl/dev/.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.7.1"
version = "0.8.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"

[compat]
AdvancedHMC = "0.7"
AdvancedHMC = "0.8"
Documenter = "1"
DocumenterCitations = "1"
2 changes: 1 addition & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function transition(
(; refreshment, τ) = κ
@set! τ.integrator = jitter(rng, τ.integrator)
z = refresh(rng, refreshment, h, z)
return transition(rng, τ, h, z)
return transition(rng, h, τ, z)
end

function Adaptation.adapt!(
Expand Down
10 changes: 5 additions & 5 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ $(SIGNATURES)

Make a MCMC transition from phase point `z` using the trajectory `τ` under Hamiltonian `h`.

NOTE: This is a RNG-implicit fallback function for `transition(Random.default_rng(), τ, h, z)`
NOTE: This is a RNG-implicit fallback function for `transition(Random.default_rng(), h, τ, z)`
"""
function transition(τ::Trajectory, h::Hamiltonian, z::PhasePoint)
return transition(Random.default_rng(), τ, h, z)
function transition(h::Hamiltonian, τ::Trajectory, z::PhasePoint)
return transition(Random.default_rng(), h, τ, z)
end

###
Expand All @@ -256,8 +256,8 @@ end

function transition(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
τ::Trajectory{TS,I,TC},
h::Hamiltonian,
τ::Trajectory{TS,I,TC},
z::PhasePoint,
) where {TS<:AbstractTrajectorySampler,I,TC<:StaticTerminationCriterion}
H0 = energy(z)
Expand Down Expand Up @@ -665,7 +665,7 @@ function build_tree(
end

function transition(
rng::AbstractRNG, τ::Trajectory{TS,I,TC}, h::Hamiltonian, z0::PhasePoint
rng::AbstractRNG, h::Hamiltonian, τ::Trajectory{TS,I,TC}, z0::PhasePoint
) where {
TS<:AbstractTrajectorySampler,I<:AbstractIntegrator,TC<:DynamicTerminationCriterion
}
Expand Down
79 changes: 41 additions & 38 deletions test/CUDA/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,50 @@ using LogDensityProblems
include(joinpath(@__DIR__, "..", "common.jl"))

@testset "AdvancedHMC GPU" begin
n_chains = 1000
n_samples = 1000
dim = 5

T = Float32
m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains)
m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀)

target = Gaussian(m, s)
metric = UnitEuclideanMetric(T, size(θ₀))
ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target)
hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ)
integrator = Leapfrog(one(T) / 5)
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))

samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
if CUDA.functional()
n_chains = 1000
n_samples = 1000
dim = 5
T = Float32
m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains)
m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀)
target = Gaussian(m, s)
metric = UnitEuclideanMetric(T, size(θ₀))
ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target)
hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ)
integrator = Leapfrog(one(T) / 5)
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))
samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
else
println("GPU tests are skipped because no CUDA devices are found.")
end
end

@testset "PhasePoint GPU" begin
for T in [Float32, Float64]
function init_z1()
return PhasePoint(
CuArray([T(NaN) T(NaN)]),
CuArray([T(NaN) T(NaN)]),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
)
if CUDA.functional()
for T in [Float32, Float64]
function init_z1()
return PhasePoint(
CuArray([T(NaN) T(NaN)]),
CuArray([T(NaN) T(NaN)]),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
)
end
function init_z2()
return PhasePoint(
CuArray([T(Inf) T(Inf)]),
CuArray([T(Inf) T(Inf)]),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
)
end
z1 = init_z1()
z2 = init_z2()
@test z1.ℓπ.value == z2.ℓπ.value
@test z1.ℓκ.value == z2.ℓκ.value
end
function init_z2()
return PhasePoint(
CuArray([T(Inf) T(Inf)]),
CuArray([T(Inf) T(Inf)]),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
)
end

z1 = init_z1()
z2 = init_z2()

@test z1.ℓπ.value == z2.ℓπ.value
@test z1.ℓκ.value == z2.ℓκ.value
else
println("GPU tests are skipped because no CUDA devices are found.")
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
6 changes: 3 additions & 3 deletions test/quality.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
using AdvancedHMC
using ReTest
using Test: Test
using Aqua: Aqua
using JET
using ForwardDiff

@testset "Aqua" begin
Test.@testset "Aqua" begin
Aqua.test_all(AdvancedHMC)
end

@testset "JET" begin
Test.@testset "JET" begin
JET.test_package(AdvancedHMC; target_defined_modules=true)
end
4 changes: 2 additions & 2 deletions test/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ end
for τ_test in [τ, τ_with_jittered_lf], seed in [1234, 5678, 90]
rng = MersenneTwister(seed)
z = AdvancedHMC.phasepoint(h, θ_init, r_init)
z1′ = AdvancedHMC.transition(rng, τ_test, h, z).z
z1′ = AdvancedHMC.transition(rng, h, τ_test, z).z

rng = MersenneTwister(seed)
z = AdvancedHMC.phasepoint(h, θ_init, r_init)
z2′ = AdvancedHMC.transition(rng, τ_test, h, z).z
z2′ = AdvancedHMC.transition(rng, h, τ_test, z).z

@test z1′.θ == z2′.θ
@test z1′.r == z2′.r
Expand Down
Loading