Skip to content

Commit

Permalink
Fix duplicate of include and some bugs in tests (#343)
Browse files Browse the repository at this point in the history
* Fix duplicate include of common.jl

* Bugfix.

* bump version

* Bugfix
  • Loading branch information
yebai authored Jul 28, 2023
1 parent 3a4b384 commit eb9b2e0
Show file tree
Hide file tree
Showing 13 changed files with 8 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.5.2"
version = "0.5.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
1 change: 0 additions & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC
using Statistics: mean
include("common.jl")

@testset "AbstractMCMC w/ gdemo" begin
rng = MersenneTwister(0)
Expand Down
9 changes: 4 additions & 5 deletions test/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using AdvancedHMC, AbstractMCMC, Random
include("common.jl")

get_kernel_hyperparams(spl::HMC, state) = state.κ.τ.termination_criterion.L
get_kernel_hyperparams(spl::HMCDA, state) = state.κ.τ.termination_criterion.λ
Expand Down Expand Up @@ -169,13 +168,13 @@ end
spl = NUTS(0.8)
T = AdvancedHMC.sampler_eltype(spl)

metric = make_metric(spl, logdensity)
metric = AdvancedHMC.make_metric(spl, logdensity)
hamiltonian = Hamiltonian(metric, model)

init_params1 = make_init_params(rng, spl, logdensity, nothing)
init_params1 = AdvancedHMC.make_init_params(rng, spl, logdensity, nothing)
@test typeof(init_params1) == Vector{T}
@test length(init_params1) == d
init_params2 = make_init_params(rng, spl, logdensity, θ_init)
@test init_params2 === θ_init
init_params2 = AdvancedHMC.make_init_params(rng, spl, logdensity, θ_init)
@test init_params2 == θ_init
end
end
2 changes: 0 additions & 2 deletions test/contrib.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using ReTest, AdvancedHMC, ForwardDiff, Zygote

include("common.jl")

@testset "contrib" begin
@testset "ad" begin
metric = UnitEuclideanMetric(D)
Expand Down
2 changes: 0 additions & 2 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using AdvancedHMC: DualValue, PhasePoint
using CUDA

@testset "AdvancedHMC GPU" begin
include("common.jl")

n_chains = 1000
n_samples = 1000
dim = 5
Expand Down
3 changes: 0 additions & 3 deletions test/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ using ReTest, AdvancedHMC
using AdvancedHMC: GaussianKinetic, DualValue, PhasePoint
using LinearAlgebra: dot, diagm


include("common.jl")

@testset "Hamiltonian" begin
f = x -> dot(x, x)
g = x -> 2x
Expand Down
1 change: 0 additions & 1 deletion test/integrator.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ReTest, Random, AdvancedHMC, ForwardDiff
include("common.jl")

using OrdinaryDiffEq
using LinearAlgebra: dot
Expand Down
1 change: 0 additions & 1 deletion test/mcmcchains.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, MCMCChains
using Statistics: mean
include("common.jl")

@testset "MCMCChains w/ gdemo" begin
rng = MersenneTwister(0)
Expand Down
2 changes: 0 additions & 2 deletions test/models.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using ReTest, Random, AdvancedHMC, ForwardDiff
using Statistics: mean
include("common.jl")


@testset "Models" begin
@testset "gdemo" begin
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using FillArrays
using AdvancedHMC: AdvancedHMC
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD
using ReTest

println("Environment variables for testing")
println(ENV)
Expand All @@ -11,6 +12,8 @@ const DIRECTORY_AdvancedHMC = dirname(dirname(pathof(AdvancedHMC)))
const DIRECTORY_Turing_tests = joinpath(DIRECTORY_AdvancedHMC, "test", "turing")
const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC")

include("common.jl")

if GROUP == "All" || GROUP == "AdvancedHMC"
using ReTest, CUDA

Expand Down
1 change: 0 additions & 1 deletion test/sampler-vec.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ReTest, AdvancedHMC, LinearAlgebra, UnicodePlots, Random
using Statistics: mean, var, cov
include("common.jl")

@testset "sample (vectorized)" begin
n_chains_max = 20
Expand Down
1 change: 0 additions & 1 deletion test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using AdvancedHMC: StaticTerminationCriterion, DynamicTerminationCriterion
using Setfield
using Statistics: mean, var, cov
unicodeplots()
include("common.jl")

function test_stats(
::Trajectory{TS,I,TC},
Expand Down
1 change: 0 additions & 1 deletion test/trajectory.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ReTest, AdvancedHMC, Random
using Statistics: mean
using LinearAlgebra: dot
include("common.jl")

function makeplot(plt, traj_θ, ts_list...)
function plotturn!(traj_θ, ts)
Expand Down

2 comments on commit eb9b2e0

@yebai
Copy link
Member Author

@yebai yebai commented on eb9b2e0 Jul 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88580

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.3 -m "<description of version>" eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130
git push origin v0.5.3

Please sign in to comment.