diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 982621c0..c1a5fdc3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,8 +15,6 @@ jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} - env: - ADSUITE: General strategy: fail-fast: false matrix: diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml index 14eafa5d..64335480 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/Enzyme.yml @@ -15,7 +15,7 @@ jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} env: - ADSUITE: Enzyme + TEST_GROUP: Enzyme runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 5c320972..125da71c 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,5 +1,7 @@ -AD_distributionsad = if AD_GROUP == "General" +AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment @@ -7,8 +9,6 @@ AD_distributionsad = if AD_GROUP == "General" :Mooncake => AutoMooncake(; config=Mooncake.Config()), :Enzyme => AutoEnzyme(), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "inference RepGradELBO DistributionsAD" begin @@ -21,7 +21,7 @@ end :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_distributionsad + (adbackname, adtype) in AD_repgradelbo_distributionsad seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 1ec761ad..4a84526b 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,13 +1,13 @@ -AD_repgradelbo_locationscale = if AD_GROUP == "General" +AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "inference ScoreGradELBO VILocationScale" begin diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index d2d6cefe..5b197cc4 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,13 +1,13 @@ -AD_repgradelbo_locationscale_bijectors = if AD_GROUP == "General" +AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "inference RepGradELBO VILocationScale Bijectors" begin diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index 94a55368..962b9e03 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -1,13 +1,13 @@ -AD_scoregradelbo_distributionsad = if AD_GROUP == "General" +AD_scoregradelbo_distributionsad = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), #:Mooncake => AutoMooncake(; config=Mooncake.Config()), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "inference ScoreGradELBO DistributionsAD" begin diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index 226d8a46..5a389e7c 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -1,13 +1,13 @@ -AD_scoregradelbo_locationscale = if AD_GROUP == "General" +AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "inference ScoreGradELBO VILocationScale" begin diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index fe06d97b..22be98ba 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -1,13 +1,13 @@ -AD_scoregradelbo_locationscale_bijectors = if AD_GROUP == "General" +AD_scoregradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), #:Zygote => AutoZygote(), #:Mooncake => AutoMooncake(; config=Mooncake.Config()), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "inference ScoreGradELBO VILocationScale Bijectors" begin diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 7bd78cb8..5f8e8f0f 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -1,15 +1,15 @@ using Test -AD_interface = if AD_GROUP == "General" +AD_interface = if TEST_GROUP == "Enzyme" + Dict(:Enzyme => AutoEnzyme()) +else Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) end @testset "ad" begin diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index a9f6ed66..b34dc659 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -27,15 +27,15 @@ using Test end end -AD_repgradelbo_stl = if AD_GROUP == "General" - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -elseif AD_GROUP == "Enzyme" - Dict(:Enzyme => AutoEnzyme()) +AD_repgradelbo_stl = if TEST_GROUP == "Enzyme" + [AutoEnzyme()] +else + [ + AutoForwardDiff(), + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] end @testset "interface RepGradELBO STL variance reduction" begin @@ -45,7 +45,7 @@ end modelstats = normal_meanfield(rng, Float64) (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - @testset for (_, adtype) in AD_repgradelbo_stl + @testset for adtype in AD_repgradelbo_stl q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) diff --git a/test/runtests.jl b/test/runtests.jl index 6c3def92..8e0cbc11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,6 @@ using ForwardDiff, ReverseDiff, Zygote, Mooncake, Enzyme using AdvancedVI const TEST_GROUP = get(ENV, "TEST_GROUP", "All") -const AD_GROUP = get(ENV, "AD_GROUP", "General") # Models for Inference Tests struct TestModel{M,L,S,SC} @@ -40,14 +39,18 @@ end include("models/normal.jl") include("models/normallognormal.jl") -# Tests if TEST_GROUP == "All" || TEST_GROUP == "Interface" - include("interface/ad.jl") + # Interface tests that do not involve testing on Enzyme include("interface/optimize.jl") - include("interface/repgradelbo.jl") - include("interface/scoregradelbo.jl") include("interface/rules.jl") include("interface/averaging.jl") + include("interface/scoregradelbo.jl") +end + +if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme" + # Interface tests that involve testing on Enzyme + include("interface/ad.jl") + include("interface/repgradelbo.jl") end if TEST_GROUP == "All" || TEST_GROUP == "Families" @@ -57,7 +60,7 @@ end const PROGRESS = haskey(ENV, "PROGRESS") -if TEST_GROUP == "All" || TEST_GROUP == "Inference" +if TEST_GROUP == "All" || TEST_GROUP == "Inference" || TEST_GROUP == "Enzyme" include("inference/repgradelbo_distributionsad.jl") include("inference/repgradelbo_locationscale.jl") include("inference/repgradelbo_locationscale_bijectors.jl")