Skip to content

Commit 1325478

Browse files
Merge pull request #54 from SciML/Diattempt2
[WIP] Fresh attempt at DI integration
2 parents 6ebd646 + c1a5e1f commit 1325478

19 files changed

+900
-3279
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
julia-version: [1]
1818
os: [ubuntu-latest]
1919
package:
20-
- {user: SciML, repo: Optimization.jl, group: Optimization}
20+
- {user: SciML, repo: Optimization.jl, group: All}
2121

2222
steps:
2323
- uses: actions/checkout@v4

Project.toml

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "1.3.3"
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -16,15 +17,15 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1617
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
1718
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1819
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
20+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
21+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
1922

2023
[weakdeps]
2124
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
22-
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2325
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
26+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2427
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
2528
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
26-
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
27-
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2829
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2930

3031
[extensions]
@@ -33,29 +34,21 @@ OptimizationFiniteDiffExt = "FiniteDiff"
3334
OptimizationForwardDiffExt = "ForwardDiff"
3435
OptimizationMTKExt = "ModelingToolkit"
3536
OptimizationReverseDiffExt = "ReverseDiff"
36-
OptimizationSparseDiffExt = ["SparseDiffTools", "ReverseDiff"]
37-
OptimizationTrackerExt = "Tracker"
3837
OptimizationZygoteExt = "Zygote"
3938

4039
[compat]
41-
ADTypes = "1.3"
40+
ADTypes = "1.5"
4241
ArrayInterface = "7.6"
42+
DifferentiationInterface = "0.5.2"
4343
DocStringExtensions = "0.9"
44-
Enzyme = "0.12.12"
45-
FiniteDiff = "2.12"
46-
ForwardDiff = "0.10.26"
4744
LinearAlgebra = "1.9, 1.10"
4845
ModelingToolkit = "9"
49-
PDMats = "0.11"
5046
Reexport = "1.2"
5147
Requires = "1"
52-
ReverseDiff = "1.14"
5348
SciMLBase = "2"
54-
SparseDiffTools = "2.14"
5549
SymbolicAnalysis = "0.1, 0.2"
5650
SymbolicIndexingInterface = "0.3"
5751
Symbolics = "5.12"
58-
Tracker = "0.2.29"
5952
Zygote = "0.6.67"
6053
julia = "1.10"
6154

ext/OptimizationEnzymeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
136136
end
137137
Enzyme.make_zero!(y)
138138
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
139-
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
139+
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)
140140
for i in 1:length(θ)
141141
if J isa Vector
142142
J[i] = Jaccache[i][1]
@@ -257,7 +257,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
257257
end
258258
Enzyme.make_zero!(y)
259259
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
260-
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
260+
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)
261261
for i in 1:length(θ)
262262
if J isa Vector
263263
J[i] = Jaccache[i][1]

0 commit comments

Comments
 (0)