Skip to content

Do not seed structural zeros #739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions src/apiutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,68 @@ end
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)
end

# Only seed indices that are structurally non-zero
structural_eachindex(x::AbstractArray) = structural_eachindex(x, x)
function structural_eachindex(x::AbstractArray, y::AbstractArray)
require_one_based_indexing(x, y)
eachindex(x, y)
end
function structural_eachindex(x::UpperTriangular, y::AbstractArray)
require_one_based_indexing(x, y)
if size(x) != size(y)
throw(DimensionMismatch())
end
n = size(x, 1)
return (CartesianIndex(i, j) for j in 1:n for i in 1:j)
end
function structural_eachindex(x::LowerTriangular, y::AbstractArray)
require_one_based_indexing(x, y)
if size(x) != size(y)
throw(DimensionMismatch())
end
n = size(x, 1)
return (CartesianIndex(i, j) for j in 1:n for i in j:n)
end
function structural_eachindex(x::Diagonal, y::AbstractArray)
require_one_based_indexing(x, y)
if size(x) != size(y)
throw(DimensionMismatch())
end
return diagind(x)
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
duals .= Dual{T,V,N}.(x, Ref(seed))
for idx in structural_eachindex(duals, x)
duals[idx] = Dual{T,V,N}(x[idx], seed)
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
dual_inds = 1:N
duals[dual_inds] .= Dual{T,V,N}.(view(x,dual_inds), seeds)
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
offset = index - 1
dual_inds = (1:N) .+ offset
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), Ref(seed))
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
for idx in idxs
duals[idx] = Dual{T,V,N}(x[idx], seed)
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
offset = index - 1
seed_inds = 1:chunksize
dual_inds = seed_inds .+ offset
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), getindex.(Ref(seeds), seed_inds))
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
for (i, idx) in zip(1:chunksize, idxs)
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
end
return duals
end
21 changes: 14 additions & 7 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
function gradient(f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
require_one_based_indexing(x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
if chunksize(cfg) == structural_length(x)
return vector_mode_gradient(f, x, cfg)
else
return chunk_mode_gradient(f, x, cfg)
Expand All @@ -35,7 +35,7 @@ This method assumes that `isa(f(x), Real)`.
function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK, F}
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
if chunksize(cfg) == structural_length(x)
vector_mode_gradient!(result, f, x, cfg)
else
chunk_mode_gradient!(result, f, x, cfg)
Expand Down Expand Up @@ -63,12 +63,19 @@ function extract_gradient!(::Type{T}, result::DiffResult, dual::Dual) where {T}
end

extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T} = fill!(result, zero(y))
extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}= copyto!(result, partials(T, dual))
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
idxs = structural_eachindex(result)
for (i, idx) in zip(1:npartials(dual), idxs)
result[idx] = partials(T, dual, i)
end
return result
end

function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
offset = index - 1
for i in 1:chunksize
result[i + offset] = partials(T, dual, i)
idxs = Iterators.drop(structural_eachindex(result), offset)
for (i, idx) in zip(1:chunksize, idxs)
result[idx] = partials(T, dual, i)
end
return result
end
Expand Down Expand Up @@ -106,10 +113,10 @@ end

function chunk_mode_gradient_expr(result_definition::Expr)
return quote
@assert length(x) >= N "chunk size cannot be greater than length(x) ($(N) > $(length(x)))"
@assert structural_length(x) >= N "chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"

# precalculate loop bounds
xlen = length(x)
xlen = structural_length(x)
remainder = xlen % N
lastchunksize = ifelse(remainder == 0, N, remainder)
lastchunkindex = xlen - lastchunksize + 1
Expand Down
12 changes: 6 additions & 6 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
function jacobian(f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
require_one_based_indexing(x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
if chunksize(cfg) == structural_length(x)
return vector_mode_jacobian(f, x, cfg)
else
return chunk_mode_jacobian(f, x, cfg)
Expand All @@ -36,7 +36,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
function jacobian(f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
require_one_based_indexing(y, x)
CHK && checktag(T, f!, x)
if chunksize(cfg) == length(x)
if chunksize(cfg) == structural_length(x)
return vector_mode_jacobian(f!, y, x, cfg)
else
return chunk_mode_jacobian(f!, y, x, cfg)
Expand All @@ -57,7 +57,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
function jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
if chunksize(cfg) == structural_length(x)
vector_mode_jacobian!(result, f, x, cfg)
else
chunk_mode_jacobian!(result, f, x, cfg)
Expand All @@ -78,7 +78,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
function jacobian!(result::Union{AbstractArray,DiffResult}, f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
result isa DiffResult ? require_one_based_indexing(y, x) : require_one_based_indexing(result, y, x)
CHK && checktag(T, f!, x)
if chunksize(cfg) == length(x)
if chunksize(cfg) == structural_length(x)
vector_mode_jacobian!(result, f!, y, x, cfg)
else
chunk_mode_jacobian!(result, f!, y, x, cfg)
Expand Down Expand Up @@ -169,10 +169,10 @@ const JACOBIAN_ERROR = DimensionMismatch("jacobian(f, x) expects that f(x) is an
function jacobian_chunk_mode_expr(work_array_definition::Expr, compute_ydual::Expr,
result_definition::Expr, y_definition::Expr)
return quote
@assert length(x) >= N "chunk size cannot be greater than length(x) ($(N) > $(length(x)))"
@assert structural_length(x) >= N "chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"

# precalculate loop bounds
xlen = length(x)
xlen = structural_length(x)
remainder = xlen % N
lastchunksize = ifelse(remainder == 0, N, remainder)
lastchunkindex = xlen - lastchunksize + 1
Expand Down
9 changes: 8 additions & 1 deletion src/prelude.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@ function Chunk(input_length::Integer, threshold::Integer = DEFAULT_CHUNK_THRESHO
Base.@nif 12 d->(N == d) d->(Chunk{d}()) d->(Chunk{N}())
end

structural_length(x::AbstractArray) = length(x)
function structural_length(x::Union{LowerTriangular,UpperTriangular})
n = size(x, 1)
return (n * (n + 1)) >> 1
end
structural_length(x::Diagonal) = size(x, 1)

function Chunk(x::AbstractArray, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
return Chunk(length(x), threshold)
return Chunk(structural_length(x), threshold)
end

# Constrained to `N <= threshold`, minimize (in order of priority):
Expand Down
29 changes: 29 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,33 @@ end
@test dx ≈ sum(a * b)
end

# issue #738
@testset "LowerTriangular, UpperTriangular and Diagonal" begin
for n in (3, 10, 20)
M = rand(n, n)
for T in (LowerTriangular, UpperTriangular, Diagonal)
@test ForwardDiff.gradient(sum, T(randn(n, n))) == T(ones(n, n))
@test ForwardDiff.gradient(x -> dot(M, x), T(randn(n, n))) == T(M)

# Check number of function evaluations and chunk sizes
fevals = Ref(0)
npartials = Ref(0)
y = ForwardDiff.gradient(T(randn(n, n))) do x
fevals[] += 1
npartials[] += ForwardDiff.npartials(eltype(x))
return sum(x)
end
if npartials[] <= ForwardDiff.DEFAULT_CHUNK_THRESHOLD
# Vector mode (single evaluation)
@test fevals[] == 1
@test npartials[] == sum(y)
else
# Chunk mode (multiple evaluations)
@test fevals[] > 1
@test sum(y) <= npartials[] < sum(y) + fevals[]
end
end
end
end

end # module
Loading