Skip to content

Commit

Permalink
Add StaticArrays broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jul 30, 2024
1 parent abe165f commit e5bb445
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ julia = "1"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffTests", "FillArrays", "IrrationalConstants", "Test"]
test = ["DiffTests", "FillArrays", "IrrationalConstants", "StaticArrays", "Test"]
8 changes: 4 additions & 4 deletions src/api/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function seeded_reverse_pass!(result::DiffResult, output::AbstractArray, input::
end

function seeded_reverse_pass!(result::Tuple, output::AbstractArray, input::Tuple, tape)
for i in eachindex(result)
result = map(eachindex(result, input)) do i
seeded_reverse_pass!(result[i], output, input[i], tape)
end
return result
Expand All @@ -75,14 +75,14 @@ end
#####################

function extract_result!(result::Tuple, output, input::Tuple)
for i in eachindex(result)
result = map(eachindex(result, input)) do i
extract_result!(result[i], output, input[i])
end
return result
end

function extract_result!(result::Tuple, output)
for i in eachindex(result)
result = map(eachindex(result)) do i
extract_result!(result[i], output)
end
return result
Expand Down Expand Up @@ -111,7 +111,7 @@ function extract_result!(result::DiffResult, output::Number)
end

function extract_result_value!(result::Tuple, output)
for i in eachindex(result)
result = map(eachindex(result)) do i
extract_result_value!(result[i], output)
end
return result
Expand Down
36 changes: 28 additions & 8 deletions test/compat/CompatTests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
module CompatTests

using FillArrays, ReverseDiff, Test
using DiffResults, FillArrays, StaticArrays, ReverseDiff, Test

@test ReverseDiff.gradient(fill(2.0, 3)) do x
sum(abs2.(x .- Zeros(3)))
end == fill(4.0, 3)
@testset "FillArrays" begin
@test ReverseDiff.gradient(fill(2.0, 3)) do x
sum(abs2.(x .- Zeros(3)))
end == fill(4.0, 3)

@test ReverseDiff.gradient(fill(2.0, 3)) do x
sum(abs2.(x .- (1:3)))
end == [2, 0, -2]
@test ReverseDiff.gradient(fill(2.0, 3)) do x
sum(abs2.(x .- (1:3)))
end == [2, 0, -2]
end

end
sumabs2(x) = sum(abs2, x)

@testset "StaticArrays" begin
@testset "Gradient" begin
x = MVector{2}(3.0, 4.0)
result = ReverseDiff.gradient!(DiffResults.GradientResult(x), sumabs2, x)
@test_broken x == [3.0, 4.0]
@test_broken DiffResults.value(result) == 25.0
end

@testset "Hessian" begin
x = MVector{2}(3.0, 4.0)
result = ReverseDiff.hessian!(DiffResults.HessianResult(x), sumabs2, x)
@test_broken x == [3.0, 4.0]
@test_broken DiffResults.value(result) == 25.0
end
end

end

0 comments on commit e5bb445

Please sign in to comment.