Skip to content

Commit ff529cb

Browse files
authored
Improve type stability of Jacobians and Hessian, fix test scenarios (#337)
* Ensure type stability of test scenarios in 1.11
* Preallocate results
* Typo
* Fix GPU scenario
* Temporarily disable tests on 1.11
* I said disable them
1 parent 83de009 commit ff529cb

File tree

7 files changed

+150
-93
lines changed

7 files changed

+150
-93
lines changed

.github/workflows/Test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
version:
3030
- '1'
3131
- '1.6'
32-
- '~1.11.0-0'
32+
# - '~1.11.0-0'
3333
group:
3434
- Formalities
3535
- Internals
@@ -118,7 +118,7 @@ jobs:
118118
version:
119119
- '1'
120120
- '1.6'
121-
- '~1.11.0-0'
121+
# - '~1.11.0-0'
122122
group:
123123
- Formalities
124124
- Zero

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,18 @@ abstract type JacobianExtras <: Extras end
5555

5656
struct NoJacobianExtras <: JacobianExtras end
5757

58-
struct PushforwardJacobianExtras{B,D,E<:PushforwardExtras,Y} <: JacobianExtras
58+
struct PushforwardJacobianExtras{B,D,R,E<:PushforwardExtras} <: JacobianExtras
5959
batched_seeds::Vector{Batch{B,D}}
60+
batched_results::Vector{Batch{B,R}}
6061
pushforward_batched_extras::E
61-
y_example::Y
62+
N::Int
6263
end
6364

64-
struct PullbackJacobianExtras{B,D,E<:PullbackExtras,Y} <: JacobianExtras
65+
struct PullbackJacobianExtras{B,D,R,E<:PullbackExtras} <: JacobianExtras
6566
batched_seeds::Vector{Batch{B,D}}
67+
batched_results::Vector{Batch{B,R}}
6668
pullback_batched_extras::E
67-
y_example::Y
69+
M::Int
6870
end
6971

7072
function prepare_jacobian(f::F, backend::AbstractADType, x) where {F}
@@ -85,14 +87,15 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
8587
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
8688
a in 1:div(N, B, RoundUp)
8789
])
90+
batched_results = Batch.([ntuple(b -> similar(y), Val(B)) for _ in batched_seeds])
8891
pushforward_batched_extras = prepare_pushforward_batched(
8992
f_or_f!y..., backend, x, batched_seeds[1]
9093
)
91-
D = eltype(seeds)
94+
D = eltype(batched_seeds[1])
95+
R = eltype(batched_results[1])
9296
E = typeof(pushforward_batched_extras)
93-
Y = typeof(y)
94-
return PushforwardJacobianExtras{B,D,E,Y}(
95-
batched_seeds, pushforward_batched_extras, copy(y)
97+
return PushforwardJacobianExtras{B,D,R,E}(
98+
batched_seeds, batched_results, pushforward_batched_extras, N
9699
)
97100
end
98101

@@ -105,13 +108,16 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) wh
105108
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
106109
a in 1:div(M, B, RoundUp)
107110
])
111+
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
108112
pullback_batched_extras = prepare_pullback_batched(
109113
f_or_f!y..., backend, x, batched_seeds[1]
110114
)
111-
D = eltype(seeds)
115+
D = eltype(batched_seeds[1])
116+
R = eltype(batched_results[1])
112117
E = typeof(pullback_batched_extras)
113-
Y = typeof(y)
114-
return PullbackJacobianExtras{B,D,E,Y}(batched_seeds, pullback_batched_extras, copy(y))
118+
return PullbackJacobianExtras{B,D,R,E}(
119+
batched_seeds, batched_results, pullback_batched_extras, M
120+
)
115121
end
116122

117123
## One argument
@@ -209,8 +215,7 @@ end
209215
function jacobian_aux(
210216
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
211217
) where {FY,B}
212-
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
213-
N = length(x)
218+
@compat (; batched_seeds, pushforward_batched_extras, N) = extras
214219

215220
pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
216221
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
@@ -233,8 +238,7 @@ end
233238
function jacobian_aux(
234239
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
235240
) where {FY,B}
236-
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
237-
M = length(y_example)
241+
@compat (; batched_seeds, pullback_batched_extras, M) = extras
238242

239243
pullback_batched_extras_same = prepare_pullback_batched_same_point(
240244
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
@@ -261,27 +265,32 @@ function jacobian_aux!(
261265
x::AbstractArray,
262266
extras::PushforwardJacobianExtras{B},
263267
) where {FY,B}
264-
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
265-
N = length(x)
268+
@compat (; batched_seeds, batched_results, pushforward_batched_extras, N) = extras
266269

267270
pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
268271
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
269272
)
270273

271-
for a in eachindex(batched_seeds)
272-
dy_batch_elements = ntuple(Val(B)) do b
273-
reshape(view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), size(y_example))
274-
end
274+
for a in eachindex(batched_seeds, batched_results)
275275
pushforward_batched!(
276276
f_or_f!y...,
277-
Batch(dy_batch_elements),
277+
batched_results[a],
278278
backend,
279279
x,
280280
batched_seeds[a],
281281
pushforward_batched_extras_same,
282282
)
283283
end
284284

285+
for a in eachindex(batched_results)
286+
for b in eachindex(batched_results[a].elements)
287+
copyto!(
288+
view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N),
289+
vec(batched_results[a].elements[b]),
290+
)
291+
end
292+
end
293+
285294
return jac
286295
end
287296

@@ -292,26 +301,31 @@ function jacobian_aux!(
292301
x::AbstractArray,
293302
extras::PullbackJacobianExtras{B},
294303
) where {FY,B}
295-
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
296-
M = length(y_example)
304+
@compat (; batched_seeds, batched_results, pullback_batched_extras, M) = extras
297305

298306
pullback_batched_extras_same = prepare_pullback_batched_same_point(
299307
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
300308
)
301309

302-
for a in eachindex(batched_seeds)
303-
dx_batch_elements = ntuple(Val(B)) do b
304-
reshape(view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), size(x))
305-
end
310+
for a in eachindex(batched_seeds, batched_results)
306311
pullback_batched!(
307312
f_or_f!y...,
308-
Batch(dx_batch_elements),
313+
batched_results[a],
309314
backend,
310315
x,
311316
batched_seeds[a],
312317
pullback_batched_extras_same,
313318
)
314319
end
315320

321+
for a in eachindex(batched_results)
322+
for b in eachindex(batched_results[a].elements)
323+
copyto!(
324+
view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :),
325+
vec(batched_results[a].elements[b]),
326+
)
327+
end
328+
end
329+
316330
return jac
317331
end

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ abstract type HessianExtras <: Extras end
4949

5050
struct NoHessianExtras <: HessianExtras end
5151

52-
struct HVPGradientHessianExtras{B,D,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
52+
struct HVPGradientHessianExtras{B,D,R,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
5353
batched_seeds::Vector{Batch{B,D}}
54+
batched_results::Vector{Batch{B,R}}
5455
hvp_batched_extras::E2
5556
gradient_extras::E1
57+
N::Int
5658
end
5759

5860
function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
@@ -64,12 +66,14 @@ function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
6466
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
6567
a in 1:div(N, B, RoundUp)
6668
])
69+
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
6770
hvp_batched_extras = prepare_hvp_batched(f, backend, x, batched_seeds[1])
6871
gradient_extras = prepare_gradient(f, maybe_inner(backend), x)
69-
D = eltype(seeds)
72+
D = eltype(batched_seeds[1])
73+
R = eltype(batched_results[1])
7074
E2, E1 = typeof(hvp_batched_extras), typeof(gradient_extras)
71-
return HVPGradientHessianExtras{B,D,E2,E1}(
72-
batched_seeds, hvp_batched_extras, gradient_extras
75+
return HVPGradientHessianExtras{B,D,R,E2,E1}(
76+
batched_seeds, batched_results, hvp_batched_extras, gradient_extras, N
7377
)
7478
end
7579

@@ -100,8 +104,7 @@ end
100104
function hessian(
101105
f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B}
102106
) where {F,B}
103-
@compat (; batched_seeds, hvp_batched_extras) = extras
104-
N = length(x)
107+
@compat (; batched_seeds, hvp_batched_extras, N) = extras
105108

106109
hvp_batched_extras_same = prepare_hvp_batched_same_point(
107110
f, backend, x, batched_seeds[1], hvp_batched_extras
@@ -122,27 +125,27 @@ end
122125
function hessian!(
123126
f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B}
124127
) where {F,B}
125-
@compat (; batched_seeds, hvp_batched_extras) = extras
126-
N = length(x)
128+
@compat (; batched_seeds, batched_results, hvp_batched_extras, N) = extras
127129

128130
hvp_batched_extras_same = prepare_hvp_batched_same_point(
129131
f, backend, x, batched_seeds[1], hvp_batched_extras
130132
)
131133

132-
for a in eachindex(batched_seeds)
133-
dg_batch_elements = ntuple(Val(B)) do b
134-
reshape(view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), size(x))
135-
end
134+
for a in eachindex(batched_seeds, batched_results)
136135
hvp_batched!(
137-
f,
138-
Batch(dg_batch_elements),
139-
backend,
140-
x,
141-
batched_seeds[a],
142-
hvp_batched_extras_same,
136+
f, batched_results[a], backend, x, batched_seeds[a], hvp_batched_extras_same
143137
)
144138
end
145139

140+
for a in eachindex(batched_results)
141+
for b in eachindex(batched_results[a].elements)
142+
copyto!(
143+
view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N),
144+
vec(batched_results[a].elements[b]),
145+
)
146+
end
147+
end
148+
146149
return hess
147150
end
148151

DifferentiationInterface/src/sparse/hessian.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
struct SparseHessianExtras{
2-
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,E2<:HVPExtras,E1<:GradientExtras
2+
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,R,E2<:HVPExtras,E1<:GradientExtras
33
} <: HessianExtras
44
sparsity::S
55
colors::Vector{Int}
66
groups::Vector{Vector{Int}}
77
compressed::C
88
batched_seeds::Vector{Batch{B,D}}
9+
batched_results::Vector{Batch{B,R}}
910
hvp_batched_extras::E2
1011
gradient_extras::E1
1112
end
@@ -16,16 +17,18 @@ function SparseHessianExtras{B}(;
1617
groups,
1718
compressed::C,
1819
batched_seeds::Vector{Batch{B,D}},
20+
batched_results::Vector{Batch{B,R}},
1921
hvp_batched_extras::E2,
2022
gradient_extras::E1,
21-
) where {B,S,C,D,E2,E1}
23+
) where {B,S,C,D,R,E2,E1}
2224
@assert size(sparsity, 1) == size(sparsity, 2) == size(compressed, 1) == length(colors)
23-
return SparseHessianExtras{B,S,C,D,E2,E1}(
25+
return SparseHessianExtras{B,S,C,D,R,E2,E1}(
2426
sparsity,
2527
colors,
2628
groups,
2729
compressed,
2830
batched_seeds,
31+
batched_results,
2932
hvp_batched_extras,
3033
gradient_extras,
3134
)
@@ -48,6 +51,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
4851
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for
4952
a in 1:div(Ng, B, RoundUp)
5053
])
54+
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
5155
hvp_batched_extras = prepare_hvp_batched(f, dense_backend, x, batched_seeds[1])
5256
gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x)
5357
return SparseHessianExtras{B}(;
@@ -56,6 +60,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
5660
groups,
5761
compressed,
5862
batched_seeds,
63+
batched_results,
5964
hvp_batched_extras,
6065
gradient_extras,
6166
)
@@ -86,29 +91,42 @@ end
8691
function hessian!(
8792
f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B}
8893
) where {F,B}
89-
@compat (; sparsity, compressed, colors, groups, batched_seeds, hvp_batched_extras) =
90-
extras
94+
@compat (;
95+
sparsity,
96+
compressed,
97+
colors,
98+
groups,
99+
batched_seeds,
100+
batched_results,
101+
hvp_batched_extras,
102+
) = extras
91103
dense_backend = dense_ad(backend)
92104
Ng = length(groups)
93105

94106
hvp_batched_extras_same = prepare_hvp_batched_same_point(
95107
f, dense_backend, x, batched_seeds[1], hvp_batched_extras
96108
)
97109

98-
for a in 1:div(Ng, B, RoundUp)
99-
dg_batch_elements = ntuple(Val(B)) do b
100-
reshape(view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng), size(x))
101-
end
110+
for a in eachindex(batched_seeds, batched_results)
102111
hvp_batched!(
103112
f,
104-
Batch(dg_batch_elements),
113+
batched_results[a],
105114
dense_backend,
106115
x,
107116
batched_seeds[a],
108117
hvp_batched_extras_same,
109118
)
110119
end
111120

121+
for a in eachindex(batched_results)
122+
for b in eachindex(batched_results[a].elements)
123+
copyto!(
124+
view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
125+
vec(batched_results[a].elements[b]),
126+
)
127+
end
128+
end
129+
112130
decompress_symmetric!(hess, sparsity, compressed, colors)
113131
return hess
114132
end

0 commit comments

Comments
 (0)