Skip to content

Commit

Permalink
Merge pull request #90 from SciML/type2
Browse files Browse the repository at this point in the history
More extensive handling of dispatching on type
  • Loading branch information
ChrisRackauckas authored Dec 31, 2023
2 parents e379918 + 11ec603 commit b52a345
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 10 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PreallocationTools"
uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "0.4.14"
version = "0.4.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -10,9 +10,11 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[weakdeps]
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[extensions]
PreallocationToolsReverseDiffExt = "ReverseDiff"
PreallocationToolsSymbolicsExt = "Symbolics"

[compat]
Adapt = "3.4, 4"
Expand Down
49 changes: 49 additions & 0 deletions ext/PreallocationToolsSymbolicsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module PreallocationToolsSymbolicsExt

using PreallocationTools
import PreallocationTools: _restructure, get_tmp
using Symbolics, ForwardDiff

function get_tmp(dc::DiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::DiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::DiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::FixedSizeDiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::FixedSizeDiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

end
25 changes: 24 additions & 1 deletion src/PreallocationTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ function get_tmp(dc::FixedSizeDiffCache, u::Union{Number, AbstractArray})
end
end

function get_tmp(dc::FixedSizeDiffCache, ::Type{T}) where T <: Number
if promote_type(eltype(dc.du), T) <: eltype(dc.du)
dc.du
else
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end
end

# DiffCache

struct DiffCache{T <: AbstractArray, S <: AbstractArray}
Expand Down Expand Up @@ -125,7 +136,7 @@ function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual}
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
end

function get_tmp(dc::DiffCache, u::Type{T}) where {T <: ForwardDiff.Dual}
function get_tmp(dc::DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual}
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
enlargediffcache!(dc, nelem)
Expand Down Expand Up @@ -153,6 +164,18 @@ function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray})
end
end

function get_tmp(dc::DiffCache, ::Type{T}) where T <: Number
if promote_type(eltype(dc.du), T) <: eltype(dc.du)
dc.du
else
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end

_restructure(dc.du, dc.any_du)
end
end

get_tmp(dc, u) = dc

function _restructure(normal_cache::Array, duals)
Expand Down
16 changes: 8 additions & 8 deletions test/core_odes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra,

#Base array
function foo(du, u, (A, tmp), t)
tmp = get_tmp(tmp, u)
tmp = get_tmp(tmp, promote_type(eltype(u), typeof(t)))
mul!(tmp, A, u)
@. du = u + tmp
nothing
Expand All @@ -15,23 +15,23 @@ u0 = ones(5, 5)
A = ones(5, 5)
cache = DiffCache(zeros(5, 5), chunk_size)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
@test sol.retcode == ReturnCode.Success

cache = FixedSizeDiffCache(zeros(5, 5), chunk_size)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
@test sol.retcode == ReturnCode.Success

#with auto-detected chunk_size
cache = DiffCache(zeros(5, 5))
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(5, 5), (0.0, 1.0), (A, cache))
sol = solve(prob, TRBDF2())
sol = solve(prob, Rodas5P())
@test sol.retcode == ReturnCode.Success

prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0),
(ones(5, 5), FixedSizeDiffCache(zeros(5, 5))))
sol = solve(prob, TRBDF2())
sol = solve(prob, Rodas5P())
@test sol.retcode == ReturnCode.Success

#Base array with LBC
Expand All @@ -43,7 +43,7 @@ function foo(du, u, (A, lbc), t)
end
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(5, 5), (0.0, 1.0),
(ones(5, 5), LazyBufferCache()))
sol = solve(prob, TRBDF2())
sol = solve(prob, Rodas5P())
@test sol.retcode == ReturnCode.Success

#LArray
Expand All @@ -60,9 +60,9 @@ end
chunk_size = 4
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0),
(A, DiffCache(c, chunk_size)))
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
@test sol.retcode == ReturnCode.Success
#with auto-detected chunk_size
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, DiffCache(c)))
sol = solve(prob, TRBDF2())
sol = solve(prob, Rodas5P())
@test sol.retcode == ReturnCode.Success
23 changes: 23 additions & 0 deletions test/sparsity_support.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,26 @@ A = Symbolics.jacobian_sparsity(wrap_res_dae_cache!, output, x0)
B = sparse([1 0; 0 1])
@test A == B
@test findnz(A) == findnz(B)

# Test Nesting https://discourse.julialang.org/t/preallocationtools-jl-with-nested-forwarddiff-and-sparsity-pattern-detection-errors/107897

function foo(x, cache)
d = get_tmp(cache, x)

d[:] = x

0.5 * x'*x
end

function residual(r, x, cache)
function foo_wrap(x)
foo(x, cache)
end

r[:] = ForwardDiff.gradient(foo_wrap, x)
end

cache = DiffCache(zeros(2))
pattern = Symbolics.jacobian_sparsity((r, x) -> residual(r, x, cache), zeros(2), zeros(2))
@test pattern == sparse([1 0
0 1])

0 comments on commit b52a345

Please sign in to comment.