Skip to content

Fix complexity of getindex(::ScalarFunctionIterator, i) #1257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 138 additions & 48 deletions src/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,81 @@ function scalar_type(::Type{MOI.VectorQuadraticFunction{T}}) where {T}
return MOI.ScalarQuadraticFunction{T}
end

struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction}
"""
ScalarFunctionIterator{F<:MOI.AbstractVectorFunction}

A type that allows iterating over the scalar-functions that comprise an
`AbstractVectorFunction`.
"""
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction, C}
f::F
# Cache that can be used to store a precomputed datastructure that allows
# an efficient implementation of `getindex`.
cache::C
end
function ScalarFunctionIterator(func::MOI.AbstractVectorFunction)
return ScalarFunctionIterator(
func,
scalar_iterator_cache(func),
)
end

scalar_iterator_cache(func::MOI.AbstractVectorFunction) = nothing

function output_index_iterator(terms::AbstractVector, output_dimension)
start = zeros(Int, output_dimension)
next = Vector{Int}(undef, length(terms))
last = zeros(Int, output_dimension)
for i in eachindex(terms)
j = terms[i].output_index
if iszero(last[j])
start[j] = i
else
next[last[j]] = i
end
last[j] = i
end
for j in eachindex(last)
if !iszero(last[j])
next[last[j]] = 0
end
end
return ChainedIterator(start, next)
end
struct ChainedIterator
start::Vector{Int}
next::Vector{Int}
end
struct ChainedIteratorAtIndex
start::Int
next::Vector{Int}
end
function ChainedIteratorAtIndex(it::ChainedIterator, index::Int)
return ChainedIteratorAtIndex(it.start[index], it.next)
end
#TODO We could also precompute the length for each `output_index`,
# check that it's a win.
Base.IteratorSize(::ChainedIteratorAtIndex) = Base.SizeUnknown()
function Base.iterate(it::ChainedIteratorAtIndex, i = it.start)
if iszero(i)
return nothing
else
return i, it.next[i]
end
end

function ScalarFunctionIterator(f::MOI.VectorAffineFunction)
return ScalarFunctionIterator(f, output_index_iterator(f.terms, MOI.output_dimension(f)))
end

function ScalarFunctionIterator(f::MOI.VectorQuadraticFunction)
return ScalarFunctionIterator(
f,
(output_index_iterator(f.affine_terms, MOI.output_dimension(f)),
output_index_iterator(f.quadratic_terms, MOI.output_dimension(f))),
)
end

eachscalar(f::MOI.AbstractVectorFunction) = ScalarFunctionIterator(f)
eachscalar(f::AbstractVector) = f

Expand All @@ -344,70 +416,88 @@ Base.lastindex(it::ScalarFunctionIterator) = length(it)

# Define getindex for Vector functions

# VectorOfVariables

function Base.getindex(
it::ScalarFunctionIterator{MOI.VectorOfVariables},
i::Integer,
)
return MOI.SingleVariable(it.f.variables[i])
end
# Returns the scalar terms of output_index i
function scalar_terms_at_index(
terms::Vector{<:Union{MOI.VectorAffineTerm,MOI.VectorQuadraticTerm}},
i::Int,
output_index::Integer,
)
return [term.scalar_term for term in terms if term.output_index == i]
end
function Base.getindex(it::ScalarFunctionIterator{<:VAF}, i::Integer)
return SAF(scalar_terms_at_index(it.f.terms, i), it.f.constants[i])
end
function Base.getindex(it::ScalarFunctionIterator{<:VQF}, i::Integer)
lin = scalar_terms_at_index(it.f.affine_terms, i)
quad = scalar_terms_at_index(it.f.quadratic_terms, i)
return SQF(lin, quad, it.f.constants[i])
return MOI.SingleVariable(it.f.variables[output_index])
end

function Base.getindex(
it::ScalarFunctionIterator{MOI.VectorOfVariables},
I::AbstractVector,
output_indices::AbstractVector{<:Integer},
)
return MOI.VectorOfVariables(it.f.variables[I])
return MOI.VectorOfVariables(it.f.variables[output_indices])
end

# VectorAffineFunction

function Base.getindex(
it::ScalarFunctionIterator{MOI.VectorAffineFunction{T}},
output_index::Integer,
) where {T}
return MOI.ScalarAffineFunction{T}(
MOI.ScalarAffineTerm{T}[
it.f.terms[i].scalar_term
for i in ChainedIteratorAtIndex(it.cache, output_index)
],
it.f.constants[output_index],
)
end

function Base.getindex(
it::ScalarFunctionIterator{VAF{T}},
I::AbstractVector,
it::ScalarFunctionIterator{MOI.VectorAffineFunction{T}},
output_indices::AbstractVector{<:Integer},
) where {T}
terms = MOI.VectorAffineTerm{T}[]
# assume at least one term per index
sizehint!(terms, length(I))
constant = it.f.constants[I]
for term in it.f.terms
idx = findfirst(Base.Fix1(==, term.output_index), I)
if idx !== nothing
push!(terms, MOI.VectorAffineTerm(idx, term.scalar_term))
for (i, output_index) in enumerate(output_indices)
for j in ChainedIteratorAtIndex(it.cache, output_index)
push!(terms, MOI.VectorAffineTerm(i, it.f.terms[j].scalar_term))
end
end
return VAF(terms, constant)
return MOI.VectorAffineFunction(terms, it.f.constants[output_indices])
end

# VectorQuadraticFunction

function Base.getindex(
it::ScalarFunctionIterator{VQF{T}},
I::AbstractVector,
) where {T}
affine_terms = MOI.VectorAffineTerm{T}[]
quadratic_terms = MOI.VectorQuadraticTerm{T}[]
constant = Vector{T}(undef, length(I))
for (i, j) in enumerate(I)
g = it[j]
append!(
affine_terms,
map(t -> MOI.VectorAffineTerm(i, t), g.affine_terms),
)
append!(
quadratic_terms,
map(t -> MOI.VectorQuadraticTerm(i, t), g.quadratic_terms),
)
constant[i] = g.constant
it::ScalarFunctionIterator{MOI.VectorQuadraticFunction{T}},
output_index::Integer,
) where {T}
return MOI.ScalarQuadraticFunction(
MOI.ScalarAffineTerm{T}[
it.f.affine_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[1], output_index)
],
MOI.ScalarQuadraticTerm{T}[
it.f.quadratic_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[2], output_index)
],
it.f.constants[output_index],
)
end

function Base.getindex(
it::ScalarFunctionIterator{MOI.VectorQuadraticFunction{T}},
output_indices::AbstractVector{<:Integer},
) where {T}
vat = MOI.VectorAffineTerm{T}[]
vqt = MOI.VectorQuadraticTerm{T}[]
for (i, output_index) in enumerate(output_indices)
for j in ChainedIteratorAtIndex(it.cache[1], output_index)
push!(
vat,
MOI.VectorAffineTerm(i, it.f.affine_terms[j].scalar_term),
)
end
for j in ChainedIteratorAtIndex(it.cache[2], output_index)
push!(
vqt,
MOI.VectorQuadraticTerm(i, it.f.quadratic_terms[j].scalar_term),
)
end
end
return VQF(affine_terms, quadratic_terms, constant)
return MOI.VectorQuadraticFunction(vat, vqt, it.f.constants[output_indices])
end

function zero_with_output_dimension(::Type{Vector{T}}, n::Integer) where {T}
Expand Down