Skip to content

Remove AD backend loops in test suite #2564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
Expand Down Expand Up @@ -52,7 +53,7 @@ Combinatorics = "1"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.36"
DynamicPPL = "0.36.6"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
Expand Down
136 changes: 119 additions & 17 deletions test/test_utils/ad_utils.jl → test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
module ADUtils
module TuringADTests

using ForwardDiff: ForwardDiff
using Pkg: Pkg
using Turing
using DynamicPPL
using DynamicPPL.TestUtils: DEMO_MODELS
using DynamicPPL.TestUtils.AD: run_ad
using Random: Random
using ReverseDiff: ReverseDiff
using Mooncake: Mooncake
using Test: Test
using Turing: Turing
using Turing: DynamicPPL

export ADTypeCheckContext, adbackends

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Stuff for checking that the right AD backend is being used.
using StableRNGs: StableRNG
using Test
using ..Models: gdemo_default
import ForwardDiff, ReverseDiff, Mooncake

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
Expand Down Expand Up @@ -178,16 +174,122 @@ function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, l
return logp, vi
end

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# List of AD backends to test.

"""
All the ADTypes on which we want to run the tests.
"""
const ADTYPES = [
    Turing.AutoForwardDiff(),
    Turing.AutoReverseDiff(; compile=false),
    Turing.AutoMooncake(; config=nothing),
]

# Self-test for ADTypeCheckContext itself: sampling with a matching AD
# backend must succeed, and sampling with a mismatched backend must throw
# an AbstractWrongADBackendError.
@testset "ADTypeCheckContext" begin
    # Minimal single-variable model to drive the samplers.
    @model test_model() = x ~ Normal(0, 1)
    tm = test_model()
    adtypes = (
        Turing.AutoForwardDiff(),
        Turing.AutoReverseDiff(),
        # Don't need to test Mooncake as it doesn't use tracer types
    )
    for actual_adtype in adtypes
        sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
        for expected_adtype in adtypes
            # Wrap the model's context so evaluation checks the element
            # types produced during AD against `expected_adtype`.
            contextualised_tm = DynamicPPL.contextualize(
                tm, ADTypeCheckContext(expected_adtype, tm.context)
            )
            @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
                if actual_adtype == expected_adtype
                    # Check that this does not throw an error.
                    Turing.sample(contextualised_tm, sampler, 2)
                else
                    # Mismatched backend: sampling must abort with the
                    # context's wrong-AD-backend error.
                    @test_throws AbstractWrongADBackendError Turing.sample(
                        contextualised_tm, sampler, 2
                    )
                end
            end
        end
    end
end

@testset verbose = true "AD / ADTypeCheckContext" begin
    # This testset ensures that samplers or optimisers don't accidentally
    # override the AD backend set in it.
    @testset "adtype=$adtype" for adtype in ADTYPES
        seed = 123
        alg = HMC(0.1, 10; adtype=adtype)
        # Wrap gdemo_default's context so every model evaluation verifies
        # that the element types it sees came from `adtype`.
        m = DynamicPPL.contextualize(
            gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
        )
        # These will error if the adbackend being used is not the one set.
        sample(StableRNG(seed), m, alg, 10)
        maximum_likelihood(m; adtype=adtype)
        maximum_a_posteriori(m; adtype=adtype)
    end
end

@testset verbose = true "AD / SamplingContext" begin
    # AD tests for gradient-based samplers need to be run with SamplingContext
    # because samplers can potentially use this to define custom behaviour in
    # the tilde-pipeline and thus change the code executed during model
    # evaluation.
    @testset "adtype=$adtype" for adtype in ADTYPES
        # One instance of each gradient-based sampler, constructed with the
        # adtype under test.
        @testset "alg=$alg" for alg in [
            HMC(0.1, 10; adtype=adtype),
            HMCDA(0.8, 0.75; adtype=adtype),
            NUTS(1000, 0.8; adtype=adtype),
            SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
            SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
        ]
            @info "Testing AD for $alg"

            @testset "model=$(model.f)" for model in DEMO_MODELS
                rng = StableRNG(123)
                # Evaluate each demo model the way the sampler would: inside
                # a SamplingContext carrying the sampler and a seeded RNG.
                ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
                # NOTE(review): `test=true` presumably enables run_ad's own
                # gradient check (see DynamicPPL.TestUtils.AD); the `isa Any`
                # assertion only requires that the call completes.
                @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
            end
        end
    end
end

@testset verbose = true "AD / GibbsContext" begin
    # Gibbs sampling also needs extra AD testing because the models are
    # executed with GibbsContext and a subsetted varinfo. (see e.g.
    # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
    # src/mcmc/gibbs.jl -- the code here mimics what happens in those
    # functions)
    @testset "adtype=$adtype" for adtype in ADTYPES
        @testset "model=$(model.f)" for model in DEMO_MODELS
            # All the demo models have variables `s` and `m`, so we'll pretend
            # that we're using a Gibbs sampler where both of them are sampled
            # with a gradient-based sampler (say HMC(0.1, 10)).
            # This means we need to construct one model conditioned on only `s`,
            # and one conditioned on only `m`.
            global_vi = DynamicPPL.VarInfo(model)
            @testset for varnames in ([@varname(s)], [@varname(m)])
                @info "Testing Gibbs AD with model=$(model.f), varnames=$varnames"
                conditioned_model = Turing.Inference.make_conditional(
                    model, varnames, deepcopy(global_vi)
                )
                rng = StableRNG(123)
                ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
                # BUG FIX: previously `model` was passed to `run_ad` here, so
                # `conditioned_model` was constructed but never used and the
                # Gibbs-conditioned evaluation path was never actually tested.
                @test run_ad(conditioned_model, adtype; context=ctx, test=true, benchmark=false) isa
                    Any
            end
        end
    end
end

@testset verbose = true "AD / Gibbs sampling" begin
    # Make sure that Gibbs sampling doesn't fall over when using AD.
    @testset "adtype=$adtype" for adtype in ADTYPES
        # Both Gibbs components use the same gradient-based sampler with the
        # adtype under test.
        spl = Gibbs(
            @varname(s) => HMC(0.1, 10; adtype=adtype),
            @varname(m) => HMC(0.1, 10; adtype=adtype),
        )
        @testset "model=$(model.f)" for model in DEMO_MODELS
            # Smoke test only: two iterations, assert the call completes.
            @test sample(model, spl, 2) isa Any
        end
    end
end

end # module
56 changes: 22 additions & 34 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module InferenceTests

using ..Models: gdemo_d, gdemo_default
using ..NumericalTests: check_gdemo, check_numerical
import ..ADUtils
using Distributions: Bernoulli, Beta, InverseGamma, Normal
using Distributions: sample
import DynamicPPL
Expand All @@ -17,8 +16,9 @@ import Mooncake
using Test: @test, @test_throws, @testset
using Turing

@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
@info "Starting Inference.jl tests with $adbackend"
@testset verbose = true "Testing Inference.jl" begin
@info "Starting Inference.jl tests"

seed = 23

@testset "threaded sampling" begin
Expand All @@ -27,12 +27,12 @@ using Turing
model = gdemo_default

samplers = (
HMC(0.1, 7; adtype=adbackend),
HMC(0.1, 7),
PG(10),
IS(),
MH(),
Gibbs(:s => PG(3), :m => HMC(0.4, 8; adtype=adbackend)),
Gibbs(:s => HMC(0.1, 5; adtype=adbackend), :m => ESS()),
Gibbs(:s => PG(3), :m => HMC(0.4, 8)),
Gibbs(:s => HMC(0.1, 5), :m => ESS()),
)
for sampler in samplers
Random.seed!(5)
Expand All @@ -44,7 +44,7 @@ using Turing
@test chain1.value == chain2.value
end

# Should also be stable with am explicit RNG
# Should also be stable with an explicit RNG
seed = 5
rng = Random.MersenneTwister(seed)
for sampler in samplers
Expand All @@ -61,27 +61,22 @@ using Turing
# Smoke test for default sample call.
@testset "gdemo_default" begin
chain = sample(
StableRNG(seed),
gdemo_default,
HMC(0.1, 7; adtype=adbackend),
MCMCThreads(),
1_000,
4,
StableRNG(seed), gdemo_default, HMC(0.1, 7), MCMCThreads(), 1_000, 4
)
check_gdemo(chain)

# run sampler: progress logging should be disabled and
# it should return a Chains object
sampler = Sampler(HMC(0.1, 7; adtype=adbackend))
sampler = Sampler(HMC(0.1, 7))
chains = sample(StableRNG(seed), gdemo_default, sampler, MCMCThreads(), 10, 4)
@test chains isa MCMCChains.Chains
end
end

@testset "chain save/resume" begin
alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend)
alg1 = HMCDA(1000, 0.65, 0.15)
alg2 = PG(20)
alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4; adtype=adbackend))
alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4))

chn1 = sample(StableRNG(seed), gdemo_default, alg1, 10_000; save_state=true)
check_gdemo(chn1)
Expand Down Expand Up @@ -260,7 +255,7 @@ using Turing

smc = SMC()
pg = PG(10)
gibbs = Gibbs(:p => HMC(0.2, 3; adtype=adbackend), :x => PG(10))
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))

chn_s = sample(StableRNG(seed), testbb(obs), smc, 200)
chn_p = sample(StableRNG(seed), testbb(obs), pg, 200)
Expand All @@ -273,22 +268,17 @@ using Turing

@testset "forbid global" begin
xs = [1.5 2.0]
# xx = 1

@model function fggibbstest(xs)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
# xx ~ Normal(m, sqrt(s)) # this is illegal

for i in 1:length(xs)
xs[i] ~ Normal(m, sqrt(s))
# for xx in xs
# xx ~ Normal(m, sqrt(s))
end
return s, m
end

gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8; adtype=adbackend))
gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8))
chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2)
end

Expand Down Expand Up @@ -353,7 +343,7 @@ using Turing
)
end

# TODO(mhauru) What is this testing? Why does it not use the looped-over adbackend?
# TODO(mhauru) What is this testing? Why does it use a different adbackend?
@testset "new interface" begin
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]

Expand Down Expand Up @@ -382,9 +372,7 @@ using Turing
end
end

chain = sample(
StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10; adtype=adbackend), 4000
)
chain = sample(StableRNG(seed), noreturn([1.5 2.0]), HMC(0.1, 10), 4000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6])
end

Expand Down Expand Up @@ -415,7 +403,7 @@ using Turing
end

@testset "sample" begin
alg = Gibbs(:m => HMC(0.2, 3; adtype=adbackend), :s => PG(10))
alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10))
chn = sample(StableRNG(seed), gdemo_default, alg, 10)
end

Expand All @@ -427,7 +415,7 @@ using Turing
return s, m
end

alg = HMC(0.01, 5; adtype=adbackend)
alg = HMC(0.01, 5)
x = randn(100)
res = sample(StableRNG(seed), vdemo1(x), alg, 10)

Expand All @@ -442,7 +430,7 @@ using Turing

# Vector assumptions
N = 10
alg = HMC(0.2, 4; adtype=adbackend)
alg = HMC(0.2, 4)

@model function vdemo3()
x = Vector{Real}(undef, N)
Expand Down Expand Up @@ -497,7 +485,7 @@ using Turing
return s, m
end

alg = HMC(0.01, 5; adtype=adbackend)
alg = HMC(0.01, 5)
x = randn(100)
res = sample(StableRNG(seed), vdemo1(x), alg, 10)

Expand All @@ -507,12 +495,12 @@ using Turing
end

D = 2
alg = HMC(0.01, 5; adtype=adbackend)
alg = HMC(0.01, 5)
res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10)

# Vector assumptions
N = 10
alg = HMC(0.2, 4; adtype=adbackend)
alg = HMC(0.2, 4)

@model function vdemo3()
x = Vector{Real}(undef, N)
Expand Down Expand Up @@ -559,7 +547,7 @@ using Turing

@testset "Type parameters" begin
N = 10
alg = HMC(0.01, 5; adtype=adbackend)
alg = HMC(0.01, 5)
x = randn(1000)
@model function vdemo1(::Type{T}=Float64) where {T}
x = Vector{T}(undef, N)
Expand Down
Loading
Loading