From d485aa0ddb9008f1ebe8d71d93fe68e2fcd851ab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 12 Jun 2024 20:48:09 -0700 Subject: [PATCH] Mark some of the tests as broken --- Project.toml | 1 + test/layer_tests.jl | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ac7a29d..7f1394d 100644 --- a/Project.toml +++ b/Project.toml @@ -43,6 +43,7 @@ Artifacts = "1.10" ChainRulesCore = "1.24" ComponentArrays = "0.15.13" ConcreteStructs = "0.2.3" +DataInterpolations = "5.2.0" ExplicitImports = "1.5" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 080bc24..b29017d 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -150,8 +150,8 @@ end ongpu && continue @testset "$(spl): train_grid $(train_grid), dims $(dims)" for spl in ( - ConstantInterpolation, LinearInterpolation, QuadraticInterpolation, - QuadraticSpline, CubicSpline), + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, QuadraticSpline, CubicSpline), train_grid in (true, false), dims in ((), (8,)) @@ -164,12 +164,14 @@ end y, st = spline(x, ps, st) @test size(y) == (dims..., 4) - @jet spline(x, ps, st) + opt_broken = !ongpu && dims != () && spl !== ConstantInterpolation + + @jet spline(x, ps, st) opt_broken=opt_broken # See SciML/DataInterpolations.jl/issues/267 y, st = spline(x, ps_ca, st) @test size(y) == (dims..., 4) - @jet spline(x, ps_ca, st) + @jet spline(x, ps_ca, st) opt_broken=opt_broken # See SciML/DataInterpolations.jl/issues/267 ∂x, ∂ps = Zygote.gradient((x, ps) -> sum(abs2, first(spline(x, ps, st))), x, ps) spl !== ConstantInterpolation && @test ∂x !== nothing