Skip to content

Commit

Permalink
refactor test group organizations
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Oct 21, 2024
1 parent b92d382 commit 09b81ee
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 41 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
env:
ADSUITE: General
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Enzyme.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
env:
ADSUITE: Enzyme
TEST_GROUP: Enzyme
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand Down
8 changes: 4 additions & 4 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@

AD_distributionsad = if AD_GROUP == "General"
AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
:Enzyme => AutoEnzyme(),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "inference RepGradELBO DistributionsAD" begin
Expand All @@ -21,7 +21,7 @@ end
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_distributionsad
(adbackname, adtype) in AD_repgradelbo_distributionsad

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)
Expand Down
6 changes: 3 additions & 3 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

AD_repgradelbo_locationscale = if AD_GROUP == "General"
AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "inference ScoreGradELBO VILocationScale" begin
Expand Down
6 changes: 3 additions & 3 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

AD_repgradelbo_locationscale_bijectors = if AD_GROUP == "General"
AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "inference RepGradELBO VILocationScale Bijectors" begin
Expand Down
6 changes: 3 additions & 3 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

AD_scoregradelbo_distributionsad = if AD_GROUP == "General"
AD_scoregradelbo_distributionsad = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
#:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "inference ScoreGradELBO DistributionsAD" begin
Expand Down
6 changes: 3 additions & 3 deletions test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

AD_scoregradelbo_locationscale = if AD_GROUP == "General"
AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "inference ScoreGradELBO VILocationScale" begin
Expand Down
6 changes: 3 additions & 3 deletions test/inference/scoregradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

AD_scoregradelbo_locationscale_bijectors = if AD_GROUP == "General"
AD_scoregradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "inference ScoreGradELBO VILocationScale Bijectors" begin
Expand Down
6 changes: 3 additions & 3 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@

using Test

AD_interface = if AD_GROUP == "General"
AD_interface = if TEST_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
end

@testset "ad" begin
Expand Down
20 changes: 10 additions & 10 deletions test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ using Test
end
end

AD_repgradelbo_stl = if AD_GROUP == "General"
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Mooncake => AutoMooncake(; config=Mooncake.Config()),
)
elseif AD_GROUP == "Enzyme"
Dict(:Enzyme => AutoEnzyme())
AD_repgradelbo_stl = if TEST_GROUP == "Enzyme"
[AutoEnzyme()]
else
[
AutoForwardDiff(),
AutoReverseDiff(),
AutoZygote(),
AutoMooncake(; config=Mooncake.Config()),
]
end

@testset "interface RepGradELBO STL variance reduction" begin
Expand All @@ -45,7 +45,7 @@ end
modelstats = normal_meanfield(rng, Float64)
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

@testset for (_, adtype) in AD_repgradelbo_stl
@testset for adtype in AD_repgradelbo_stl
q_true = MeanFieldGaussian(
Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true)))
)
Expand Down
15 changes: 9 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ using ForwardDiff, ReverseDiff, Zygote, Mooncake, Enzyme
using AdvancedVI

const TEST_GROUP = get(ENV, "TEST_GROUP", "All")
const AD_GROUP = get(ENV, "AD_GROUP", "General")

# Models for Inference Tests
struct TestModel{M,L,S,SC}
Expand All @@ -40,14 +39,18 @@ end
include("models/normal.jl")
include("models/normallognormal.jl")

# Tests
if TEST_GROUP == "All" || TEST_GROUP == "Interface"
include("interface/ad.jl")
# Interface tests that do not involve testing on Enzyme
include("interface/optimize.jl")
include("interface/repgradelbo.jl")
include("interface/scoregradelbo.jl")
include("interface/rules.jl")
include("interface/averaging.jl")
include("interface/scoregradelbo.jl")
end

if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme"
# Interface tests that involve testing on Enzyme
include("interface/ad.jl")
include("interface/repgradelbo.jl")
end

if TEST_GROUP == "All" || TEST_GROUP == "Families"
Expand All @@ -57,7 +60,7 @@ end

const PROGRESS = haskey(ENV, "PROGRESS")

if TEST_GROUP == "All" || TEST_GROUP == "Inference"
if TEST_GROUP == "All" || TEST_GROUP == "Inference" || TEST_GROUP == "Enzyme"
include("inference/repgradelbo_distributionsad.jl")
include("inference/repgradelbo_locationscale.jl")
include("inference/repgradelbo_locationscale_bijectors.jl")
Expand Down

0 comments on commit 09b81ee

Please sign in to comment.