Skip to content

Commit f59a89c

Browse files
committed
Move ADTypeCheckContexts and optimisation tests to ad.jl as well
1 parent 5d654c7 commit f59a89c

File tree

8 files changed

+37
-50
lines changed

8 files changed

+37
-50
lines changed

test/ad.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ using Turing
44
using DynamicPPL
55
using DynamicPPL.TestUtils: DEMO_MODELS
66
using DynamicPPL.TestUtils.AD: run_ad
7+
using Random: Random
78
using StableRNGs: StableRNG
89
using Test
910
using ..Models: gdemo_default
11+
import ForwardDiff, ReverseDiff, Mooncake
1012

1113
"""Element types that are always valid for a VarInfo regardless of ADType."""
1214
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
@@ -181,17 +183,49 @@ ADTYPES = [
181183
Turing.AutoMooncake(; config=nothing),
182184
]
183185

186+
# Check that ADTypeCheckContext itself works as expected.
187+
@testset "ADTypeCheckContext" begin
188+
@model test_model() = x ~ Normal(0, 1)
189+
tm = test_model()
190+
adtypes = (
191+
Turing.AutoForwardDiff(),
192+
Turing.AutoReverseDiff(),
193+
# TODO: Mooncake
194+
# Turing.AutoMooncake(config=nothing),
195+
)
196+
for actual_adtype in adtypes
197+
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
198+
for expected_adtype in adtypes
199+
contextualised_tm = DynamicPPL.contextualize(
200+
tm, ADTypeCheckContext(expected_adtype, tm.context)
201+
)
202+
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
203+
if actual_adtype == expected_adtype
204+
# Check that this does not throw an error.
205+
Turing.sample(contextualised_tm, sampler, 2)
206+
else
207+
@test_throws AbstractWrongADBackendError Turing.sample(
208+
contextualised_tm, sampler, 2
209+
)
210+
end
211+
end
212+
end
213+
end
214+
end
215+
184216
@testset verbose = true "AD / ADTypeCheckContext" begin
185-
# This testset ensures that samplers don't accidentally override the AD
186-
# backend set in it.
187-
@testset "Check ADType" begin
217+
# This testset ensures that samplers or optimisers don't accidentally
218+
# override the AD backend set in it.
219+
@testset "adtype=$adtype" for adtype in ADTYPES
188220
seed = 123
189221
alg = HMC(0.1, 10; adtype=adtype)
190222
m = DynamicPPL.contextualize(
191223
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
192224
)
193225
# These will error if the adbackend being used is not the one set.
194226
sample(StableRNG(seed), m, alg, 10)
227+
maximum_likelihood(m; adtype=adtype)
228+
maximum_a_posteriori(m; adtype=adtype)
195229
end
196230
end
197231

test/mcmc/Inference.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module InferenceTests
22

33
using ..Models: gdemo_d, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5-
import ..ADUtils
65
using Distributions: Bernoulli, Beta, InverseGamma, Normal
76
using Distributions: sample
87
import DynamicPPL

test/mcmc/abstractmcmc.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module AbstractMCMCTests
22

3-
import ..ADUtils
43
using AbstractMCMC: AbstractMCMC
54
using AdvancedMH: AdvancedMH
65
using Distributions: sample

test/mcmc/gibbs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using ..NumericalTests:
77
check_gdemo,
88
check_numerical,
99
two_sample_test
10-
import ..ADUtils
1110
import Combinatorics
1211
using AbstractMCMC: AbstractMCMC
1312
using Distributions: InverseGamma, Normal

test/mcmc/hmc.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
module HMCTests
22

33
using ..Models: gdemo_default
4-
using ..ADUtils: ADTypeCheckContext
54
using ..NumericalTests: check_gdemo, check_numerical
6-
import ..ADUtils
75
using Bijectors: Bijectors
86
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
97
using DynamicPPL: DynamicPPL, Sampler

test/mcmc/sghmc.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module SGHMCTests
22

33
using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
5-
import ..ADUtils
65
using DynamicPPL.TestUtils.AD: run_ad
76
using DynamicPPL.TestUtils: DEMO_MODELS
87
using DynamicPPL: DynamicPPL

test/optimisation/Optimisation.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module OptimisationTests
22

33
using ..Models: gdemo, gdemo_default
4-
using ..ADUtils: ADUtils
54
using Distributions
65
using Distributions.FillArrays: Zeros
76
using DynamicPPL: DynamicPPL
@@ -624,16 +623,6 @@ using Turing
624623
@assert get(result, :c) == (; :c => Array{Float64}[])
625624
end
626625

627-
@testset "ADType test with $adbackend" for adbackend in ADUtils.adbackends
628-
Random.seed!(222)
629-
m = DynamicPPL.contextualize(
630-
gdemo_default, ADUtils.ADTypeCheckContext(adbackend, gdemo_default.context)
631-
)
632-
# These will error if the adbackend being used is not the one set.
633-
maximum_likelihood(m; adtype=adbackend)
634-
maximum_a_posteriori(m; adtype=adbackend)
635-
end
636-
637626
@testset "Collinear coeftable" begin
638627
xs = [-1.0, 0.0, 1.0]
639628
ys = [0.0, 0.0, 0.0]

test/test_utils/test_utils.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,4 @@ using Test: @test, @testset, @test_throws
88
using Turing: Turing
99
using Turing: DynamicPPL
1010

11-
# Check that the ADTypeCheckContext works as expected.
12-
@testset "ADTypeCheckContext" begin
13-
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
14-
tm = test_model()
15-
adtypes = (
16-
Turing.AutoForwardDiff(),
17-
Turing.AutoReverseDiff(),
18-
# TODO: Mooncake
19-
# Turing.AutoMooncake(config=nothing),
20-
)
21-
for actual_adtype in adtypes
22-
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
23-
for expected_adtype in adtypes
24-
contextualised_tm = DynamicPPL.contextualize(
25-
tm, ADTypeCheckContext(expected_adtype, tm.context)
26-
)
27-
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
28-
if actual_adtype == expected_adtype
29-
# Check that this does not throw an error.
30-
Turing.sample(contextualised_tm, sampler, 2)
31-
else
32-
@test_throws AbstractWrongADBackendError Turing.sample(
33-
contextualised_tm, sampler, 2
34-
)
35-
end
36-
end
37-
end
38-
end
39-
end
40-
4111
end

0 commit comments

Comments
 (0)