Skip to content

Commit ef08ebd

Browse files
authored
Cached getindex for scalar iterator (#1263)
1 parent b82a5bc commit ef08ebd

File tree

1 file changed

+63
-29
lines changed

1 file changed

+63
-29
lines changed

src/Utilities/functions.jl

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -326,39 +326,73 @@ end
326326
A type that allows iterating over the scalar-functions that comprise an
327327
`AbstractVectorFunction`.
328328
"""
329-
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction}
329+
struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction, C}
330330
f::F
331-
# Vectors which map output indices to their terms.
332-
affine::Vector{Vector{Int}}
333-
quadratic::Vector{Vector{Int}}
331+
# Cache that can be used to store a precomputed datastructure that allows
332+
# an efficient implementation of `getindex`.
333+
cache::C
334334
end
335-
336-
function ScalarFunctionIterator(f::MOI.VectorOfVariables)
335+
function ScalarFunctionIterator(func::MOI.AbstractVectorFunction)
337336
return ScalarFunctionIterator(
338-
f,
339-
Vector{Int}[],
340-
Vector{Int}[],
337+
func,
338+
scalar_iterator_cache(func),
341339
)
342340
end
343341

344-
function ScalarFunctionIterator(f::MOI.VectorAffineFunction)
345-
d = [Int[] for i = 1:MOI.output_dimension(f)]
346-
for (i, term) in enumerate(f.terms)
347-
push!(d[term.output_index], i)
342+
scalar_iterator_cache(func::MOI.AbstractVectorFunction) = nothing
343+
344+
function output_index_iterator(terms::AbstractVector, output_dimension)
345+
start = zeros(Int, output_dimension)
346+
next = Vector{Int}(undef, length(terms))
347+
last = zeros(Int, output_dimension)
348+
for i in eachindex(terms)
349+
j = terms[i].output_index
350+
if iszero(last[j])
351+
start[j] = i
352+
else
353+
next[last[j]] = i
354+
end
355+
last[j] = i
348356
end
349-
return ScalarFunctionIterator(f, d, Vector{Int}[])
357+
for j in eachindex(last)
358+
if !iszero(last[j])
359+
next[last[j]] = 0
360+
end
361+
end
362+
return ChainedIterator(start, next)
363+
end
364+
struct ChainedIterator
365+
start::Vector{Int}
366+
next::Vector{Int}
367+
end
368+
struct ChainedIteratorAtIndex
369+
start::Int
370+
next::Vector{Int}
371+
end
372+
function ChainedIteratorAtIndex(it::ChainedIterator, index::Int)
373+
return ChainedIteratorAtIndex(it.start[index], it.next)
374+
end
375+
#TODO We could also precompute the length for each `output_index`,
376+
# check that it's a win.
377+
Base.IteratorSize(::ChainedIteratorAtIndex) = Base.SizeUnknown()
378+
function Base.iterate(it::ChainedIteratorAtIndex, i = it.start)
379+
if iszero(i)
380+
return nothing
381+
else
382+
return i, it.next[i]
383+
end
384+
end
385+
386+
function ScalarFunctionIterator(f::MOI.VectorAffineFunction)
387+
return ScalarFunctionIterator(f, output_index_iterator(f.terms, MOI.output_dimension(f)))
350388
end
351389

352390
function ScalarFunctionIterator(f::MOI.VectorQuadraticFunction)
353-
aff = [Int[] for i = 1:MOI.output_dimension(f)]
354-
quad = [Int[] for i = 1:MOI.output_dimension(f)]
355-
for (i, term) in enumerate(f.affine_terms)
356-
push!(aff[term.output_index], i)
357-
end
358-
for (i, term) in enumerate(f.quadratic_terms)
359-
push!(quad[term.output_index], i)
360-
end
361-
return ScalarFunctionIterator(f, aff, quad)
391+
return ScalarFunctionIterator(
392+
f,
393+
(output_index_iterator(f.affine_terms, MOI.output_dimension(f)),
394+
output_index_iterator(f.quadratic_terms, MOI.output_dimension(f))),
395+
)
362396
end
363397

364398
eachscalar(f::MOI.AbstractVectorFunction) = ScalarFunctionIterator(f)
@@ -407,7 +441,7 @@ function Base.getindex(
407441
return MOI.ScalarAffineFunction{T}(
408442
MOI.ScalarAffineTerm{T}[
409443
it.f.terms[i].scalar_term
410-
for i in it.affine[output_index]
444+
for i in ChainedIteratorAtIndex(it.cache, output_index)
411445
],
412446
it.f.constants[output_index],
413447
)
@@ -419,7 +453,7 @@ function Base.getindex(
419453
) where {T}
420454
terms = MOI.VectorAffineTerm{T}[]
421455
for (i, output_index) in enumerate(output_indices)
422-
for j in it.affine[output_index]
456+
for j in ChainedIteratorAtIndex(it.cache, output_index)
423457
push!(terms, MOI.VectorAffineTerm(i, it.f.terms[j].scalar_term))
424458
end
425459
end
@@ -434,10 +468,10 @@ function Base.getindex(
434468
) where {T}
435469
return MOI.ScalarQuadraticFunction(
436470
MOI.ScalarAffineTerm{T}[
437-
it.f.affine_terms[i].scalar_term for i in it.affine[output_index]
471+
it.f.affine_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[1], output_index)
438472
],
439473
MOI.ScalarQuadraticTerm{T}[
440-
it.f.quadratic_terms[i].scalar_term for i in it.quadratic[output_index]
474+
it.f.quadratic_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[2], output_index)
441475
],
442476
it.f.constants[output_index],
443477
)
@@ -450,13 +484,13 @@ function Base.getindex(
450484
vat = MOI.VectorAffineTerm{T}[]
451485
vqt = MOI.VectorQuadraticTerm{T}[]
452486
for (i, output_index) in enumerate(output_indices)
453-
for j in it.affine[output_index]
487+
for j in ChainedIteratorAtIndex(it.cache[1], output_index)
454488
push!(
455489
vat,
456490
MOI.VectorAffineTerm(i, it.f.affine_terms[j].scalar_term),
457491
)
458492
end
459-
for j in it.quadratic[output_index]
493+
for j in ChainedIteratorAtIndex(it.cache[2], output_index)
460494
push!(
461495
vqt,
462496
MOI.VectorQuadraticTerm(i, it.f.quadratic_terms[j].scalar_term),

0 commit comments

Comments
 (0)