Skip to content

Cached getindex for scalar iterator #1263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 63 additions & 29 deletions src/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,39 +326,73 @@ end
A type that allows iterating over the scalar-functions that comprise an
`AbstractVectorFunction`.
"""
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction}
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction, C}
f::F
# Vectors which map output indices to their terms.
affine::Vector{Vector{Int}}
quadratic::Vector{Vector{Int}}
# Cache that can be used to store a precomputed datastructure that allows
# an efficient implementation of `getindex`.
cache::C
end

function ScalarFunctionIterator(f::MOI.VectorOfVariables)
function ScalarFunctionIterator(func::MOI.AbstractVectorFunction)
return ScalarFunctionIterator(
f,
Vector{Int}[],
Vector{Int}[],
func,
scalar_iterator_cache(func),
)
end

function ScalarFunctionIterator(f::MOI.VectorAffineFunction)
d = [Int[] for i = 1:MOI.output_dimension(f)]
for (i, term) in enumerate(f.terms)
push!(d[term.output_index], i)
scalar_iterator_cache(func::MOI.AbstractVectorFunction) = nothing

function output_index_iterator(terms::AbstractVector, output_dimension)
start = zeros(Int, output_dimension)
next = Vector{Int}(undef, length(terms))
last = zeros(Int, output_dimension)
for i in eachindex(terms)
j = terms[i].output_index
if iszero(last[j])
start[j] = i
else
next[last[j]] = i
end
last[j] = i
end
return ScalarFunctionIterator(f, d, Vector{Int}[])
for j in eachindex(last)
if !iszero(last[j])
next[last[j]] = 0
end
end
return ChainedIterator(start, next)
end
struct ChainedIterator
start::Vector{Int}
next::Vector{Int}
end
struct ChainedIteratorAtIndex
start::Int
next::Vector{Int}
end
function ChainedIteratorAtIndex(it::ChainedIterator, index::Int)
return ChainedIteratorAtIndex(it.start[index], it.next)
end
#TODO We could also precompute the length for each `output_index`,
# check that it's a win.
Base.IteratorSize(::ChainedIteratorAtIndex) = Base.SizeUnknown()
function Base.iterate(it::ChainedIteratorAtIndex, i = it.start)
if iszero(i)
return nothing
else
return i, it.next[i]
end
end

function ScalarFunctionIterator(f::MOI.VectorAffineFunction)
return ScalarFunctionIterator(f, output_index_iterator(f.terms, MOI.output_dimension(f)))
end

function ScalarFunctionIterator(f::MOI.VectorQuadraticFunction)
aff = [Int[] for i = 1:MOI.output_dimension(f)]
quad = [Int[] for i = 1:MOI.output_dimension(f)]
for (i, term) in enumerate(f.affine_terms)
push!(aff[term.output_index], i)
end
for (i, term) in enumerate(f.quadratic_terms)
push!(quad[term.output_index], i)
end
return ScalarFunctionIterator(f, aff, quad)
return ScalarFunctionIterator(
f,
(output_index_iterator(f.affine_terms, MOI.output_dimension(f)),
output_index_iterator(f.quadratic_terms, MOI.output_dimension(f))),
)
end

eachscalar(f::MOI.AbstractVectorFunction) = ScalarFunctionIterator(f)
Expand Down Expand Up @@ -407,7 +441,7 @@ function Base.getindex(
return MOI.ScalarAffineFunction{T}(
MOI.ScalarAffineTerm{T}[
it.f.terms[i].scalar_term
for i in it.affine[output_index]
for i in ChainedIteratorAtIndex(it.cache, output_index)
],
it.f.constants[output_index],
)
Expand All @@ -419,7 +453,7 @@ function Base.getindex(
) where {T}
terms = MOI.VectorAffineTerm{T}[]
for (i, output_index) in enumerate(output_indices)
for j in it.affine[output_index]
for j in ChainedIteratorAtIndex(it.cache, output_index)
push!(terms, MOI.VectorAffineTerm(i, it.f.terms[j].scalar_term))
end
end
Expand All @@ -434,10 +468,10 @@ function Base.getindex(
) where {T}
return MOI.ScalarQuadraticFunction(
MOI.ScalarAffineTerm{T}[
it.f.affine_terms[i].scalar_term for i in it.affine[output_index]
it.f.affine_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[1], output_index)
],
MOI.ScalarQuadraticTerm{T}[
it.f.quadratic_terms[i].scalar_term for i in it.quadratic[output_index]
it.f.quadratic_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[2], output_index)
],
it.f.constants[output_index],
)
Expand All @@ -450,13 +484,13 @@ function Base.getindex(
vat = MOI.VectorAffineTerm{T}[]
vqt = MOI.VectorQuadraticTerm{T}[]
for (i, output_index) in enumerate(output_indices)
for j in it.affine[output_index]
for j in ChainedIteratorAtIndex(it.cache[1], output_index)
push!(
vat,
MOI.VectorAffineTerm(i, it.f.affine_terms[j].scalar_term),
)
end
for j in it.quadratic[output_index]
for j in ChainedIteratorAtIndex(it.cache[2], output_index)
push!(
vqt,
MOI.VectorQuadraticTerm(i, it.f.quadratic_terms[j].scalar_term),
Expand Down