@@ -14,20 +14,30 @@ void_setindex!(args...) = (setindex!(args...); return)
14
14
15
15
const default_chunk_size = ForwardDiff. pickchunksize
16
16
17
- function ForwardColorJacCache (f,x,_chunksize = nothing ;
17
+ function ForwardColorJacCache (f:: F ,x,_chunksize = nothing ;
18
18
dx = nothing ,
19
19
colorvec= 1 : length (x),
20
- sparsity:: Union{AbstractArray,Nothing} = nothing )
20
+ sparsity:: Union{AbstractArray,Nothing} = nothing ) where {F}
21
21
22
22
if _chunksize isa Nothing
23
23
chunksize = ForwardDiff. pickchunksize (maximum (colorvec))
24
24
else
25
25
chunksize = _chunksize
26
26
end
27
27
28
- p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
29
- _t = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} .(vec (x),first (p))
30
- t = ArrayInterface. restructure (x,_t)
28
+ if x isa Array
29
+ p = generate_chunked_partials (x,colorvec,chunksize)
30
+ t = similar (x,Dual{typeof (ForwardDiff. Tag (f,eltype (vec (x))))})
31
+ for i in eachindex (t)
32
+ t[i] = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} (x[i],first (p)[1 ])
33
+ end
34
+ else
35
+ p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
36
+ _t = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} .(vec (x),first (p))
37
+ t = ArrayInterface. restructure (x,_t)
38
+ end
39
+
40
+
31
41
if dx isa Nothing
32
42
fx = similar (t)
33
43
_dx = similar (x)
@@ -46,13 +56,27 @@ function generate_chunked_partials(x,colorvec,::Val{chunksize}) where chunksize
46
56
maxcolor = maximum (colorvec)
47
57
num_of_chunks = Int (ceil (maxcolor / chunksize))
48
58
padding_size = (chunksize - (maxcolor % chunksize)) % chunksize
49
- partials = colorvec .== (1 : maxcolor)'
59
+
60
+ # partials = colorvec .== (1:maxcolor)'
61
+ partials = BitMatrix (undef, length (colorvec), maxcolor)
62
+ for i in 1 : maxcolor, j in 1 : length (colorvec)
63
+ partials[j,i] = colorvec[j] == i
64
+ end
65
+
50
66
padding_matrix = BitMatrix (undef, length (x), padding_size)
51
67
partials = hcat (partials, padding_matrix)
52
68
53
- chunked_partials = map (i -> Tuple .(eachrow (partials[:,(i- 1 )* chunksize+ 1 : i* chunksize])),1 : num_of_chunks)
54
- chunked_partials
55
69
70
+ # chunked_partials = map(i -> Tuple.(eachrow(partials[:,(i-1)*chunksize+1:i*chunksize])),1:num_of_chunks)
71
+ chunked_partials = Vector {Vector{NTuple{chunksize,eltype(x)}}} (undef, num_of_chunks)
72
+ for i in 1 : num_of_chunks
73
+ tmp = Vector {NTuple{chunksize,eltype(x)}} (undef, size (partials,1 ))
74
+ for j in 1 : size (partials,1 )
75
+ tmp[j] = Tuple (@view partials[j,(i- 1 )* chunksize+ 1 : i* chunksize])
76
+ end
77
+ chunked_partials[i] = tmp
78
+ end
79
+ chunked_partials
56
80
end
57
81
58
82
@inline function forwarddiff_color_jacobian (f,
@@ -280,11 +304,26 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
280
304
281
305
for i in eachindex (p)
282
306
partial_i = p[i]
283
- vect .= Dual {typeof(ForwardDiff.Tag(f,eltype(vecx)))} .(vecx, partial_i)
307
+
308
+ if vect isa Array
309
+ @inbounds @simd ivdep for j in eachindex (vect)
310
+ vect[j] = Dual {typeof(ForwardDiff.Tag(f,eltype(vecx)))} (vecx[j], partial_i[j])
311
+ end
312
+ else
313
+ vect .= Dual {typeof(ForwardDiff.Tag(f,eltype(vecx)))} .(vecx, partial_i)
314
+ end
315
+
284
316
f (fx,t)
285
317
if ! (sparsity isa Nothing)
286
318
for j in 1 : chunksize
287
- dx .= partials .(fx, j)
319
+
320
+ if dx isa Array
321
+ @inbounds @simd ivdep for k in eachindex (dx)
322
+ dx[k] = partials (fx[k], j)
323
+ end
324
+ else
325
+ dx .= partials .(fx, j)
326
+ end
288
327
289
328
if ArrayInterface. fast_scalar_indexing (dx)
290
329
# dx is implicitly used in vecdx
@@ -313,7 +352,13 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
313
352
for j in 1 : chunksize
314
353
col_index = (i- 1 )* chunksize + j
315
354
(col_index > ncols) && return J
316
- J[:, col_index] .= partials .(vecfx, j)
355
+ if J isa Array
356
+ @inbounds @simd for k in 1 : size (J,1 )
357
+ J[k, col_index] = partials (vecfx[k], j)
358
+ end
359
+ else
360
+ J[:, col_index] .= partials .(vecfx, j)
361
+ end
317
362
end
318
363
end
319
364
end
0 commit comments