Skip to content

Commit 306f83d

Browse files
Merge pull request #149 from JuliaDiff/map
Make inference's job much easier by avoiding `map`
2 parents fb09091 + 73f2018 commit 306f83d

File tree

1 file changed

+56
-11
lines changed

1 file changed

+56
-11
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,30 @@ void_setindex!(args...) = (setindex!(args...); return)
1414

1515
const default_chunk_size = ForwardDiff.pickchunksize
1616

17-
function ForwardColorJacCache(f,x,_chunksize = nothing;
17+
function ForwardColorJacCache(f::F,x,_chunksize = nothing;
1818
dx = nothing,
1919
colorvec=1:length(x),
20-
sparsity::Union{AbstractArray,Nothing}=nothing)
20+
sparsity::Union{AbstractArray,Nothing}=nothing) where {F}
2121

2222
if _chunksize isa Nothing
2323
chunksize = ForwardDiff.pickchunksize(maximum(colorvec))
2424
else
2525
chunksize = _chunksize
2626
end
2727

28-
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
29-
_t = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),first(p))
30-
t = ArrayInterface.restructure(x,_t)
28+
if x isa Array
29+
p = generate_chunked_partials(x,colorvec,chunksize)
30+
t = similar(x,Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))})
31+
for i in eachindex(t)
32+
t[i] = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}(x[i],first(p)[1])
33+
end
34+
else
35+
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
36+
_t = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),first(p))
37+
t = ArrayInterface.restructure(x,_t)
38+
end
39+
40+
3141
if dx isa Nothing
3242
fx = similar(t)
3343
_dx = similar(x)
@@ -46,13 +56,27 @@ function generate_chunked_partials(x,colorvec,::Val{chunksize}) where chunksize
4656
maxcolor = maximum(colorvec)
4757
num_of_chunks = Int(ceil(maxcolor / chunksize))
4858
padding_size = (chunksize - (maxcolor % chunksize)) % chunksize
49-
partials = colorvec .== (1:maxcolor)'
59+
60+
# partials = colorvec .== (1:maxcolor)'
61+
partials = BitMatrix(undef, length(colorvec), maxcolor)
62+
for i in 1:maxcolor, j in 1:length(colorvec)
63+
partials[j,i] = colorvec[j] == i
64+
end
65+
5066
padding_matrix = BitMatrix(undef, length(x), padding_size)
5167
partials = hcat(partials, padding_matrix)
5268

53-
chunked_partials = map(i -> Tuple.(eachrow(partials[:,(i-1)*chunksize+1:i*chunksize])),1:num_of_chunks)
54-
chunked_partials
5569

70+
#chunked_partials = map(i -> Tuple.(eachrow(partials[:,(i-1)*chunksize+1:i*chunksize])),1:num_of_chunks)
71+
chunked_partials = Vector{Vector{NTuple{chunksize,eltype(x)}}}(undef, num_of_chunks)
72+
for i in 1:num_of_chunks
73+
tmp = Vector{NTuple{chunksize,eltype(x)}}(undef, size(partials,1))
74+
for j in 1:size(partials,1)
75+
tmp[j] = Tuple(@view partials[j,(i-1)*chunksize+1:i*chunksize])
76+
end
77+
chunked_partials[i] = tmp
78+
end
79+
chunked_partials
5680
end
5781

5882
@inline function forwarddiff_color_jacobian(f,
@@ -280,11 +304,26 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
280304

281305
for i in eachindex(p)
282306
partial_i = p[i]
283-
vect .= Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i)
307+
308+
if vect isa Array
309+
@inbounds @simd ivdep for j in eachindex(vect)
310+
vect[j] = Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}(vecx[j], partial_i[j])
311+
end
312+
else
313+
vect .= Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i)
314+
end
315+
284316
f(fx,t)
285317
if !(sparsity isa Nothing)
286318
for j in 1:chunksize
287-
dx .= partials.(fx, j)
319+
320+
if dx isa Array
321+
@inbounds @simd ivdep for k in eachindex(dx)
322+
dx[k] = partials(fx[k], j)
323+
end
324+
else
325+
dx .= partials.(fx, j)
326+
end
288327

289328
if ArrayInterface.fast_scalar_indexing(dx)
290329
#dx is implicitly used in vecdx
@@ -313,7 +352,13 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
313352
for j in 1:chunksize
314353
col_index = (i-1)*chunksize + j
315354
(col_index > ncols) && return J
316-
J[:, col_index] .= partials.(vecfx, j)
355+
if J isa Array
356+
@inbounds @simd for k in 1:size(J,1)
357+
J[k, col_index] = partials(vecfx[k], j)
358+
end
359+
else
360+
J[:, col_index] .= partials.(vecfx, j)
361+
end
317362
end
318363
end
319364
end

0 commit comments

Comments
 (0)