@@ -326,39 +326,73 @@ end
326
326
A type that allows iterating over the scalar-functions that comprise an
327
327
`AbstractVectorFunction`.
328
328
"""
329
- struct ScalarFunctionIterator{F<: MOI.AbstractVectorFunction }
329
+ struct ScalarFunctionIterator{F<: MOI.AbstractVectorFunction , C }
330
330
f:: F
331
- # Vectors which map output indices to their terms.
332
- affine :: Vector{Vector{Int}}
333
- quadratic :: Vector{Vector{Int}}
331
+ # Cache that can be used to store a precomputed datastructure that allows
332
+ # an efficient implementation of `getindex`.
333
+ cache :: C
334
334
end
335
-
336
- function ScalarFunctionIterator (f:: MOI.VectorOfVariables )
335
+ function ScalarFunctionIterator (func:: MOI.AbstractVectorFunction )
337
336
return ScalarFunctionIterator (
338
- f,
339
- Vector{Int}[],
340
- Vector{Int}[],
337
+ func,
338
+ scalar_iterator_cache (func),
341
339
)
342
340
end
343
341
344
- function ScalarFunctionIterator (f:: MOI.VectorAffineFunction )
345
- d = [Int[] for i = 1 : MOI. output_dimension (f)]
346
- for (i, term) in enumerate (f. terms)
347
- push! (d[term. output_index], i)
342
+ scalar_iterator_cache (func:: MOI.AbstractVectorFunction ) = nothing
343
+
344
+ function output_index_iterator (terms:: AbstractVector , output_dimension)
345
+ start = zeros (Int, output_dimension)
346
+ next = Vector {Int} (undef, length (terms))
347
+ last = zeros (Int, output_dimension)
348
+ for i in eachindex (terms)
349
+ j = terms[i]. output_index
350
+ if iszero (last[j])
351
+ start[j] = i
352
+ else
353
+ next[last[j]] = i
354
+ end
355
+ last[j] = i
348
356
end
349
- return ScalarFunctionIterator (f, d, Vector{Int}[])
357
+ for j in eachindex (last)
358
+ if ! iszero (last[j])
359
+ next[last[j]] = 0
360
+ end
361
+ end
362
+ return ChainedIterator (start, next)
363
+ end
364
+ struct ChainedIterator
365
+ start:: Vector{Int}
366
+ next:: Vector{Int}
367
+ end
368
+ struct ChainedIteratorAtIndex
369
+ start:: Int
370
+ next:: Vector{Int}
371
+ end
372
+ function ChainedIteratorAtIndex (it:: ChainedIterator , index:: Int )
373
+ return ChainedIteratorAtIndex (it. start[index], it. next)
374
+ end
375
+ # TODO We could also precompute the length for each `output_index`,
376
+ # check that it's a win.
377
+ Base. IteratorSize (:: ChainedIteratorAtIndex ) = Base. SizeUnknown ()
378
+ function Base. iterate (it:: ChainedIteratorAtIndex , i = it. start)
379
+ if iszero (i)
380
+ return nothing
381
+ else
382
+ return i, it. next[i]
383
+ end
384
+ end
385
+
386
+ function ScalarFunctionIterator (f:: MOI.VectorAffineFunction )
387
+ return ScalarFunctionIterator (f, output_index_iterator (f. terms, MOI. output_dimension (f)))
350
388
end
351
389
352
390
function ScalarFunctionIterator (f:: MOI.VectorQuadraticFunction )
353
- aff = [Int[] for i = 1 : MOI. output_dimension (f)]
354
- quad = [Int[] for i = 1 : MOI. output_dimension (f)]
355
- for (i, term) in enumerate (f. affine_terms)
356
- push! (aff[term. output_index], i)
357
- end
358
- for (i, term) in enumerate (f. quadratic_terms)
359
- push! (quad[term. output_index], i)
360
- end
361
- return ScalarFunctionIterator (f, aff, quad)
391
+ return ScalarFunctionIterator (
392
+ f,
393
+ (output_index_iterator (f. affine_terms, MOI. output_dimension (f)),
394
+ output_index_iterator (f. quadratic_terms, MOI. output_dimension (f))),
395
+ )
362
396
end
363
397
364
398
eachscalar (f:: MOI.AbstractVectorFunction ) = ScalarFunctionIterator (f)
@@ -407,7 +441,7 @@ function Base.getindex(
407
441
return MOI. ScalarAffineFunction {T} (
408
442
MOI. ScalarAffineTerm{T}[
409
443
it. f. terms[i]. scalar_term
410
- for i in it. affine[ output_index]
444
+ for i in ChainedIteratorAtIndex ( it. cache, output_index)
411
445
],
412
446
it. f. constants[output_index],
413
447
)
@@ -419,7 +453,7 @@ function Base.getindex(
419
453
) where {T}
420
454
terms = MOI. VectorAffineTerm{T}[]
421
455
for (i, output_index) in enumerate (output_indices)
422
- for j in it. affine[ output_index]
456
+ for j in ChainedIteratorAtIndex ( it. cache, output_index)
423
457
push! (terms, MOI. VectorAffineTerm (i, it. f. terms[j]. scalar_term))
424
458
end
425
459
end
@@ -434,10 +468,10 @@ function Base.getindex(
434
468
) where {T}
435
469
return MOI. ScalarQuadraticFunction (
436
470
MOI. ScalarAffineTerm{T}[
437
- it. f. affine_terms[i]. scalar_term for i in it. affine[output_index]
471
+ it. f. affine_terms[i]. scalar_term for i in ChainedIteratorAtIndex ( it. cache[ 1 ], output_index)
438
472
],
439
473
MOI. ScalarQuadraticTerm{T}[
440
- it. f. quadratic_terms[i]. scalar_term for i in it. quadratic[output_index]
474
+ it. f. quadratic_terms[i]. scalar_term for i in ChainedIteratorAtIndex ( it. cache[ 2 ], output_index)
441
475
],
442
476
it. f. constants[output_index],
443
477
)
@@ -450,13 +484,13 @@ function Base.getindex(
450
484
vat = MOI. VectorAffineTerm{T}[]
451
485
vqt = MOI. VectorQuadraticTerm{T}[]
452
486
for (i, output_index) in enumerate (output_indices)
453
- for j in it. affine[output_index]
487
+ for j in ChainedIteratorAtIndex ( it. cache[ 1 ], output_index)
454
488
push! (
455
489
vat,
456
490
MOI. VectorAffineTerm (i, it. f. affine_terms[j]. scalar_term),
457
491
)
458
492
end
459
- for j in it. quadratic[output_index]
493
+ for j in ChainedIteratorAtIndex ( it. cache[ 2 ], output_index)
460
494
push! (
461
495
vqt,
462
496
MOI. VectorQuadraticTerm (i, it. f. quadratic_terms[j]. scalar_term),
0 commit comments