diff --git a/Project.toml b/Project.toml index c0dd5e7..a5edb83 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PreallocationTools" uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" authors = ["Chris Rackauckas "] -version = "0.4.14" +version = "0.4.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -10,9 +10,11 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [weakdeps] ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] PreallocationToolsReverseDiffExt = "ReverseDiff" +PreallocationToolsSymbolicsExt = "Symbolics" [compat] Adapt = "3.4, 4" diff --git a/ext/PreallocationToolsSymbolicsExt.jl b/ext/PreallocationToolsSymbolicsExt.jl new file mode 100644 index 0000000..3b3a24d --- /dev/null +++ b/ext/PreallocationToolsSymbolicsExt.jl @@ -0,0 +1,49 @@ +module PreallocationToolsSymbolicsExt + +using PreallocationTools +import PreallocationTools: _restructure, get_tmp +using Symbolics, ForwardDiff + +function get_tmp(dc::DiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::DiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::DiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::FixedSizeDiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::FixedSizeDiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +end diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index c5ca91a..d0f4425 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -75,6 +75,17 @@ function get_tmp(dc::FixedSizeDiffCache, u::Union{Number, AbstractArray}) end end +function get_tmp(dc::FixedSizeDiffCache, ::Type{T}) where T <: Number + if promote_type(eltype(dc.du), T) <: eltype(dc.du) + dc.du + else + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) + end +end + # DiffCache struct DiffCache{T <: AbstractArray, S <: AbstractArray} @@ -125,7 +136,7 @@ function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual} _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) end -function get_tmp(dc::DiffCache, u::Type{T}) where {T <: ForwardDiff.Dual} +function get_tmp(dc::DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual} nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) if nelem > length(dc.dual_du) enlargediffcache!(dc, nelem) @@ -153,6 +164,18 @@ function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray}) end end +function get_tmp(dc::DiffCache, ::Type{T}) where T <: Number + if promote_type(eltype(dc.du), T) <: eltype(dc.du) + dc.du + else + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + + _restructure(dc.du, dc.any_du) + end +end + get_tmp(dc, u) = dc function _restructure(normal_cache::Array, duals) diff --git a/test/core_odes.jl b/test/core_odes.jl index 8653a7d..239128b 100644 --- a/test/core_odes.jl +++ b/test/core_odes.jl @@ -4,7 +4,7 @@ using LinearAlgebra, #Base array function foo(du, u, (A, tmp), t) - tmp = get_tmp(tmp, u) + tmp = get_tmp(tmp, promote_type(eltype(u), typeof(t))) mul!(tmp, A, u) @. du = u + tmp nothing @@ -15,23 +15,23 @@ u0 = ones(5, 5) A = ones(5, 5) cache = DiffCache(zeros(5, 5), chunk_size) prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache)) -sol = solve(prob, TRBDF2(chunk_size = chunk_size)) +sol = solve(prob, Rodas5P(chunk_size = chunk_size)) @test sol.retcode == ReturnCode.Success cache = FixedSizeDiffCache(zeros(5, 5), chunk_size) prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache)) -sol = solve(prob, TRBDF2(chunk_size = chunk_size)) +sol = solve(prob, Rodas5P(chunk_size = chunk_size)) @test sol.retcode == ReturnCode.Success #with auto-detected chunk_size cache = DiffCache(zeros(5, 5)) prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(5, 5), (0.0, 1.0), (A, cache)) -sol = solve(prob, TRBDF2()) +sol = solve(prob, Rodas5P()) @test sol.retcode == ReturnCode.Success prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0), (ones(5, 5), FixedSizeDiffCache(zeros(5, 5)))) -sol = solve(prob, TRBDF2()) +sol = solve(prob, Rodas5P()) @test sol.retcode == ReturnCode.Success #Base array with LBC @@ -43,7 +43,7 @@ function foo(du, u, (A, lbc), t) end prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(5, 5), (0.0, 1.0), (ones(5, 5), LazyBufferCache())) -sol = solve(prob, TRBDF2()) +sol = solve(prob, Rodas5P()) @test sol.retcode == ReturnCode.Success #LArray @@ -60,9 +60,9 @@ end chunk_size = 4 prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, DiffCache(c, chunk_size))) -sol = solve(prob, TRBDF2(chunk_size = chunk_size)) +sol = solve(prob, Rodas5P(chunk_size = chunk_size)) @test sol.retcode == ReturnCode.Success #with auto-detected chunk_size prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, DiffCache(c))) -sol = solve(prob, TRBDF2()) +sol = solve(prob, Rodas5P()) @test sol.retcode == ReturnCode.Success diff --git a/test/sparsity_support.jl b/test/sparsity_support.jl index b53273c..e6eba6d 100644 --- a/test/sparsity_support.jl +++ b/test/sparsity_support.jl @@ -48,3 +48,26 @@ A = Symbolics.jacobian_sparsity(wrap_res_dae_cache!, output, x0) B = sparse([1 0; 0 1]) @test A == B @test findnz(A) == findnz(B) + +# Test Nesting https://discourse.julialang.org/t/preallocationtools-jl-with-nested-forwarddiff-and-sparsity-pattern-detection-errors/107897 + +function foo(x, cache) + d = get_tmp(cache, x) + + d[:] = x + + 0.5 * x'*x +end + +function residual(r, x, cache) + function foo_wrap(x) + foo(x, cache) + end + + r[:] = ForwardDiff.gradient(foo_wrap, x) +end + +cache = DiffCache(zeros(2)) +pattern = Symbolics.jacobian_sparsity((r, x) -> residual(r, x, cache), zeros(2), zeros(2)) +@test pattern == sparse([1 0 + 0 1])