@@ -243,7 +243,7 @@ func (it *nodeIterator) seek(prefix []byte) error {
243
243
key = key [:len (key )- 1 ]
244
244
// Move forward until we're just before the closest match to key.
245
245
for {
246
- state , parentIndex , path , err := it .peek ( bytes . HasPrefix ( key , it . path ) )
246
+ state , parentIndex , path , err := it .peekSeek ( key )
247
247
if err == errIteratorEnd {
248
248
return errIteratorEnd
249
249
} else if err != nil {
@@ -255,16 +255,21 @@ func (it *nodeIterator) seek(prefix []byte) error {
255
255
}
256
256
}
257
257
258
+ // init initializes the the iterator.
259
+ func (it * nodeIterator ) init () (* nodeIteratorState , error ) {
260
+ root := it .trie .Hash ()
261
+ state := & nodeIteratorState {node : it .trie .root , index : - 1 }
262
+ if root != emptyRoot {
263
+ state .hash = root
264
+ }
265
+ return state , state .resolve (it .trie , nil )
266
+ }
267
+
258
268
// peek creates the next state of the iterator.
259
269
func (it * nodeIterator ) peek (descend bool ) (* nodeIteratorState , * int , []byte , error ) {
270
+ // Initialize the iterator if we've just started.
260
271
if len (it .stack ) == 0 {
261
- // Initialize the iterator if we've just started.
262
- root := it .trie .Hash ()
263
- state := & nodeIteratorState {node : it .trie .root , index : - 1 }
264
- if root != emptyRoot {
265
- state .hash = root
266
- }
267
- err := state .resolve (it .trie , nil )
272
+ state , err := it .init ()
268
273
return state , nil , nil , err
269
274
}
270
275
if ! descend {
@@ -292,6 +297,39 @@ func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, er
292
297
return nil , nil , nil , errIteratorEnd
293
298
}
294
299
300
+ // peekSeek is like peek, but it also tries to skip resolving hashes by skipping
301
+ // over the siblings that do not lead towards the desired seek position.
302
+ func (it * nodeIterator ) peekSeek (seekKey []byte ) (* nodeIteratorState , * int , []byte , error ) {
303
+ // Initialize the iterator if we've just started.
304
+ if len (it .stack ) == 0 {
305
+ state , err := it .init ()
306
+ return state , nil , nil , err
307
+ }
308
+ if ! bytes .HasPrefix (seekKey , it .path ) {
309
+ // If we're skipping children, pop the current node first
310
+ it .pop ()
311
+ }
312
+
313
+ // Continue iteration to the next child
314
+ for len (it .stack ) > 0 {
315
+ parent := it .stack [len (it .stack )- 1 ]
316
+ ancestor := parent .hash
317
+ if (ancestor == common.Hash {}) {
318
+ ancestor = parent .parent
319
+ }
320
+ state , path , ok := it .nextChildAt (parent , ancestor , seekKey )
321
+ if ok {
322
+ if err := state .resolve (it .trie , path ); err != nil {
323
+ return parent , & parent .index , path , err
324
+ }
325
+ return state , & parent .index , path , nil
326
+ }
327
+ // No more child nodes, move back up.
328
+ it .pop ()
329
+ }
330
+ return nil , nil , nil , errIteratorEnd
331
+ }
332
+
295
333
func (st * nodeIteratorState ) resolve (tr * Trie , path []byte ) error {
296
334
if hash , ok := st .node .(hashNode ); ok {
297
335
resolved , err := tr .resolveHash (hash , path )
@@ -304,25 +342,38 @@ func (st *nodeIteratorState) resolve(tr *Trie, path []byte) error {
304
342
return nil
305
343
}
306
344
345
+ func findChild (n * fullNode , index int , path []byte , ancestor common.Hash ) (node , * nodeIteratorState , []byte , int ) {
346
+ var (
347
+ child node
348
+ state * nodeIteratorState
349
+ childPath []byte
350
+ )
351
+ for ; index < len (n .Children ); index ++ {
352
+ if n .Children [index ] != nil {
353
+ child = n .Children [index ]
354
+ hash , _ := child .cache ()
355
+ state = & nodeIteratorState {
356
+ hash : common .BytesToHash (hash ),
357
+ node : child ,
358
+ parent : ancestor ,
359
+ index : - 1 ,
360
+ pathlen : len (path ),
361
+ }
362
+ childPath = append (childPath , path ... )
363
+ childPath = append (childPath , byte (index ))
364
+ return child , state , childPath , index
365
+ }
366
+ }
367
+ return nil , nil , nil , 0
368
+ }
369
+
307
370
func (it * nodeIterator ) nextChild (parent * nodeIteratorState , ancestor common.Hash ) (* nodeIteratorState , []byte , bool ) {
308
371
switch node := parent .node .(type ) {
309
372
case * fullNode :
310
- // Full node, move to the first non-nil child.
311
- for i := parent .index + 1 ; i < len (node .Children ); i ++ {
312
- child := node .Children [i ]
313
- if child != nil {
314
- hash , _ := child .cache ()
315
- state := & nodeIteratorState {
316
- hash : common .BytesToHash (hash ),
317
- node : child ,
318
- parent : ancestor ,
319
- index : - 1 ,
320
- pathlen : len (it .path ),
321
- }
322
- path := append (it .path , byte (i ))
323
- parent .index = i - 1
324
- return state , path , true
325
- }
373
+ //Full node, move to the first non-nil child.
374
+ if child , state , path , index := findChild (node , parent .index + 1 , it .path , ancestor ); child != nil {
375
+ parent .index = index - 1
376
+ return state , path , true
326
377
}
327
378
case * shortNode :
328
379
// Short node, return the pointer singleton child
@@ -342,6 +393,52 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
342
393
return parent , it .path , false
343
394
}
344
395
396
+ // nextChildAt is similar to nextChild, except that it targets a child as close to the
397
+ // target key as possible, thus skipping siblings.
398
+ func (it * nodeIterator ) nextChildAt (parent * nodeIteratorState , ancestor common.Hash , key []byte ) (* nodeIteratorState , []byte , bool ) {
399
+ switch n := parent .node .(type ) {
400
+ case * fullNode :
401
+ // Full node, move to the first non-nil child before the desired key position
402
+ child , state , path , index := findChild (n , parent .index + 1 , it .path , ancestor )
403
+ if child == nil {
404
+ // No more children in this fullnode
405
+ return parent , it .path , false
406
+ }
407
+ // If the child we found is already past the seek position, just return it.
408
+ if bytes .Compare (path , key ) >= 0 {
409
+ parent .index = index - 1
410
+ return state , path , true
411
+ }
412
+ // The child is before the seek position. Try advancing
413
+ for {
414
+ nextChild , nextState , nextPath , nextIndex := findChild (n , index + 1 , it .path , ancestor )
415
+ // If we run out of children, or skipped past the target, return the
416
+ // previous one
417
+ if nextChild == nil || bytes .Compare (nextPath , key ) >= 0 {
418
+ parent .index = index - 1
419
+ return state , path , true
420
+ }
421
+ // We found a better child closer to the target
422
+ state , path , index = nextState , nextPath , nextIndex
423
+ }
424
+ case * shortNode :
425
+ // Short node, return the pointer singleton child
426
+ if parent .index < 0 {
427
+ hash , _ := n .Val .cache ()
428
+ state := & nodeIteratorState {
429
+ hash : common .BytesToHash (hash ),
430
+ node : n .Val ,
431
+ parent : ancestor ,
432
+ index : - 1 ,
433
+ pathlen : len (it .path ),
434
+ }
435
+ path := append (it .path , n .Key ... )
436
+ return state , path , true
437
+ }
438
+ }
439
+ return parent , it .path , false
440
+ }
441
+
345
442
func (it * nodeIterator ) push (state * nodeIteratorState , parentIndex * int , path []byte ) {
346
443
it .path = path
347
444
it .stack = append (it .stack , state )
0 commit comments