-
-
Notifications
You must be signed in to change notification settings - Fork 47
feat: use DI for structured Jacobians #470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,6 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`. | |
stats::NLStats | ||
autodiff | ||
di_extras | ||
sdifft_extras | ||
end | ||
|
||
function reinit_cache!(cache::JacobianCache{iip}, args...; p = cache.p, | ||
|
@@ -63,31 +62,13 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, | |
|
||
if !has_analytic_jac && needs_jac | ||
autodiff = construct_concrete_adtype(f, autodiff) | ||
using_sparsedifftools = autodiff isa StructuredMatrixAutodiff | ||
# SparseMatrixColorings can't handle structured matrices | ||
if using_sparsedifftools | ||
di_extras = nothing | ||
uf = JacobianWrapper{iip}(f, p) | ||
sdifft_extras = if iip | ||
sparse_jacobian_cache( | ||
autodiff.autodiff, autodiff.sparsity_detection, uf, fu, u) | ||
else | ||
sparse_jacobian_cache(autodiff.autodiff, autodiff.sparsity_detection, | ||
uf, __maybe_mutable(u, autodiff); fx = fu) | ||
end | ||
autodiff = autodiff.autodiff # For saving we unwrap | ||
di_extras = if iip | ||
DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p)) | ||
else | ||
sdifft_extras = nothing | ||
di_extras = if iip | ||
DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p)) | ||
else | ||
DI.prepare_jacobian(f, autodiff, u, Constant(prob.p)) | ||
end | ||
DI.prepare_jacobian(f, autodiff, u, Constant(prob.p)) | ||
end | ||
else | ||
using_sparsedifftools = false | ||
di_extras = nothing | ||
sdifft_extras = nothing | ||
end | ||
|
||
J = if !needs_jac | ||
|
@@ -98,22 +79,18 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, | |
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff) | ||
else | ||
if f.jac_prototype === nothing | ||
if !using_sparsedifftools | ||
# While this is technically wasteful, it gives out the type of the Jacobian | ||
# which is needed to create the linear solver cache | ||
stats.njacs += 1 | ||
if has_analytic_jac | ||
__similar( | ||
fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) | ||
# While this is technically wasteful, it gives out the type of the Jacobian | ||
# which is needed to create the linear solver cache | ||
stats.njacs += 1 | ||
if has_analytic_jac | ||
__similar( | ||
fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) | ||
else | ||
if iip | ||
DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p)) | ||
else | ||
if iip | ||
DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p)) | ||
else | ||
DI.jacobian(f, autodiff, u, Constant(p)) | ||
end | ||
DI.jacobian(f, autodiff, u, Constant(p)) | ||
end | ||
else | ||
zero(init_jacobian(sdifft_extras; preserve_immutable = Val(true))) | ||
end | ||
else | ||
jac_proto = if eltype(f.jac_prototype) <: Bool | ||
|
@@ -126,20 +103,19 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, | |
end | ||
end | ||
|
||
return JacobianCache{iip}( | ||
J, f, fu, u, p, stats, autodiff, di_extras, sdifft_extras) | ||
return JacobianCache{iip}(J, f, fu, u, p, stats, autodiff, di_extras) | ||
end | ||
|
||
function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; stats, | ||
autodiff = nothing, kwargs...) where {F} | ||
fu = f(u, p) | ||
if SciMLBase.has_jac(f) || SciMLBase.has_vjp(f) || SciMLBase.has_jvp(f) | ||
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, nothing, nothing) | ||
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, nothing) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question about cases where DI preparation is skipped There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only where the user provides an analytic jacobian. This one is a special case for scalars when DI is skipped if jacobian/jvp/vjp is provided. |
||
end | ||
autodiff = get_dense_ad(get_concrete_forward_ad( | ||
autodiff, prob; check_forward_mode = false)) | ||
di_extras = DI.prepare_derivative(f, get_dense_ad(autodiff), u, Constant(prob.p)) | ||
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, di_extras, nothing) | ||
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, di_extras) | ||
end | ||
|
||
(cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p) | ||
|
@@ -172,27 +148,16 @@ function (cache::JacobianCache{iip})( | |
if iip | ||
if SciMLBase.has_jac(cache.f) | ||
cache.f.jac(J, u, p) | ||
elseif cache.di_extras !== nothing | ||
else | ||
DI.jacobian!( | ||
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)) | ||
else | ||
uf = JacobianWrapper{iip}(cache.f, p) | ||
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, cache.fu, u) | ||
end | ||
return J | ||
else | ||
if SciMLBase.has_jac(cache.f) | ||
return cache.f.jac(u, p) | ||
elseif cache.di_extras !== nothing | ||
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) | ||
else | ||
uf = JacobianWrapper{iip}(cache.f, p) | ||
if __can_setindex(typeof(J)) | ||
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, u) | ||
return J | ||
else | ||
return sparse_jacobian(cache.autodiff, cache.sdifft_extras, uf, u) | ||
end | ||
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) | ||
end | ||
end | ||
end | ||
|
@@ -207,10 +172,13 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType) | |
end | ||
return ad # No sparse AD | ||
else | ||
if ArrayInterface.isstructured(f.jac_prototype) | ||
return select_fastest_structured_matrix_autodiff(f.jac_prototype, f, ad) | ||
if !sparse_or_structured_prototype(f.jac_prototype) | ||
if SciMLBase.has_colorvec(f) | ||
@warn "`colorvec` is provided but `jac_prototype` is not a sparse \ | ||
or structured matrix. `colorvec` will be ignored." | ||
end | ||
return ad | ||
end | ||
|
||
return AutoSparse( | ||
ad; | ||
sparsity_detector = KnownJacobianSparsityDetector(f.jac_prototype), | ||
|
@@ -220,17 +188,14 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType) | |
end | ||
else | ||
if f.sparsity isa AbstractMatrix | ||
if f.jac_prototype !== nothing && f.jac_prototype !== f.sparsity | ||
throw(ArgumentError("`sparsity::AbstractMatrix` and `jac_prototype` cannot \ | ||
be both provided. Pass only `jac_prototype`.")) | ||
end | ||
Base.depwarn("`sparsity::typeof($(typeof(f.sparsity)))` is deprecated. \ | ||
Pass it as `jac_prototype` instead.", | ||
:NonlinearSolve) | ||
if ArrayInterface.isstructured(f.sparsity) | ||
return select_fastest_structured_matrix_autodiff(f.sparsity, f, ad) | ||
if f.jac_prototype !== f.sparsity | ||
if f.jac_prototype !== nothing && | ||
sparse_or_structured_prototype(f.jac_prototype) | ||
throw(ArgumentError("`sparsity::AbstractMatrix` and a sparse or \ | ||
structured `jac_prototype` cannot be both \ | ||
provided. Pass only `jac_prototype`.")) | ||
end | ||
end | ||
|
||
return AutoSparse( | ||
ad; | ||
sparsity_detector = KnownJacobianSparsityDetector(f.sparsity), | ||
|
@@ -252,11 +217,7 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType) | |
coloring_algorithm = GreedyColoringAlgorithm(LargestFirst()) | ||
) | ||
else | ||
if ArrayInterface.isstructured(f.jac_prototype) | ||
return select_fastest_structured_matrix_autodiff(f.jac_prototype, f, ad) | ||
end | ||
|
||
if f.jac_prototype isa AbstractSparseMatrix | ||
if sparse_or_structured_prototype(f.jac_prototype) | ||
if !(sparsity_detector isa NoSparsityDetector) | ||
@warn "`jac_prototype` is a sparse matrix but sparsity = $(f.sparsity) \ | ||
has also been specified. Ignoring sparsity field and using \ | ||
|
@@ -275,38 +236,6 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType) | |
end | ||
end | ||
|
||
@concrete struct StructuredMatrixAutodiff <: AbstractADType | ||
autodiff <: AbstractADType | ||
sparsity_detection | ||
end | ||
|
||
function select_fastest_structured_matrix_autodiff( | ||
prototype::AbstractMatrix, f::NonlinearFunction, ad::AbstractADType) | ||
sparsity_detection = if SciMLBase.has_colorvec(f) | ||
PrecomputedJacobianColorvec(; | ||
jac_prototype = prototype, | ||
f.colorvec, | ||
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode | ||
) | ||
else | ||
if ArrayInterface.fast_matrix_colors(prototype) | ||
colorvec = if ADTypes.mode(ad) isa ADTypes.ForwardMode | ||
ArrayInterface.matrix_colors(prototype) | ||
else | ||
ArrayInterface.matrix_colors(prototype') | ||
end | ||
PrecomputedJacobianColorvec(; | ||
jac_prototype = prototype, | ||
colorvec, | ||
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode | ||
) | ||
else | ||
JacPrototypeSparsityDetection(; jac_prototype = prototype) | ||
end | ||
end | ||
return StructuredMatrixAutodiff(AutoSparse(ad), sparsity_detection) | ||
end | ||
|
||
function select_fastest_coloring_algorithm( | ||
prototype, f::NonlinearFunction, ad::AbstractADType) | ||
if SciMLBase.has_colorvec(f) | ||
|
@@ -332,3 +261,8 @@ end | |
|
||
get_dense_ad(ad) = ad | ||
get_dense_ad(ad::AutoSparse) = ADTypes.dense_ad(ad) | ||
|
||
sparse_or_structured_prototype(::AbstractSparseMatrix) = true | ||
function sparse_or_structured_prototype(prototype::AbstractMatrix) | ||
return ArrayInterface.isstructured(prototype) | ||
end |
Uh oh!
There was an error while loading. Please reload this page.