@@ -26,6 +26,8 @@ import (
 	"os"
 	"sync"
 
+	"golang.org/x/sync/errgroup"
+
 	format "github.com/ipfs/go-unixfs"
 	"github.com/ipfs/go-unixfs/internal"
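For context on the new dependency, the snippet below is a minimal, self-contained sketch (not part of this change) of the errgroup behavior the rewrite relies on: `errgroup.WithContext` cancels the derived context as soon as any `Go()`-launched function returns a non-nil error, and `Wait()` returns that first error.

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	// The derived ctx is cancelled once any Go()'d function fails;
	// Wait() returns the first non-nil error.
	grp, ctx := errgroup.WithContext(context.Background())

	grp.Go(func() error {
		return errors.New("worker failed") // first error wins
	})
	grp.Go(func() error {
		<-ctx.Done() // unblocked by the failure above
		return nil
	})

	fmt.Println(grp.Wait()) // prints "worker failed"
}
```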
@@ -438,91 +440,71 @@ func (ds *Shard) walkLinks(processLinkValues func(formattedLink *ipld.Link) erro
 
 func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService, processShardValues func(formattedLink *ipld.Link) error) error {
 	const concurrency = 32
-	visit := cid.NewSet().Visit
+
+	var visitlk sync.Mutex
+	visitSet := cid.NewSet()
+	visit := visitSet.Visit
 
 	type shardCidUnion struct {
 		cid   cid.Cid
 		shard *Shard
 	}
 
+	// Setup synchronization
+	grp, errGrpCtx := errgroup.WithContext(ctx)
+
+	// Input and output queues for workers.
 	feed := make(chan *shardCidUnion)
 	out := make(chan *listCidShardUnion)
 	done := make(chan struct{})
 
-	var visitlk sync.Mutex
-	var wg sync.WaitGroup
-
-	errChan := make(chan error)
-	fetchersCtx, cancel := context.WithCancel(ctx)
-	defer wg.Wait()
-	defer cancel()
 	for i := 0; i < concurrency; i++ {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			for cdepth := range feed {
+		grp.Go(func() error {
+			for shardOrCID := range feed {
 				var shouldVisit bool
 
-				if cdepth.shard != nil {
+				if shardOrCID.shard != nil {
 					shouldVisit = true
 				} else {
 					visitlk.Lock()
-					shouldVisit = visit(cdepth.cid)
+					shouldVisit = visit(shardOrCID.cid)
 					visitlk.Unlock()
 				}
 
 				if shouldVisit {
 					var nextShard *Shard
-					if cdepth.shard != nil {
-						nextShard = cdepth.shard
+					if shardOrCID.shard != nil {
+						nextShard = shardOrCID.shard
 					} else {
-						nd, err := dserv.Get(ctx, cdepth.cid)
+						nd, err := dserv.Get(ctx, shardOrCID.cid)
 						if err != nil {
-							if err != nil {
-								select {
-								case errChan <- err:
-								case <-fetchersCtx.Done():
-								}
-								return
-							}
+							return err
 						}
 						nextShard, err = NewHamtFromDag(dserv, nd)
 						if err != nil {
-							if err != nil {
-								if err != nil {
-									select {
-									case errChan <- err:
-									case <-fetchersCtx.Done():
-									}
-									return
-								}
-							}
+							return err
 						}
 					}
 
 					nextLinks, err := nextShard.walkLinks(processShardValues)
 					if err != nil {
-						select {
-						case errChan <- err:
-						case <-fetchersCtx.Done():
-						}
-						return
+						return err
 					}
 
 					select {
 					case out <- nextLinks:
-					case <-fetchersCtx.Done():
-						return
+					case <-errGrpCtx.Done():
+						return nil
 					}
 				}
 				select {
 				case done <- struct{}{}:
-				case <-fetchersCtx.Done():
+				case <-errGrpCtx.Done():
 				}
 			}
-		}()
+			return nil
+		})
 	}
-	defer close(feed)
 
 	send := feed
 	var todoQueue []*shardCidUnion
@@ -532,6 +514,7 @@ func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService,
 		shard: root,
 	}
 
+dispatcherLoop:
 	for {
 		select {
 		case send <- next:
@@ -546,40 +529,39 @@ func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService,
 		case <-done:
 			inProgress--
 			if inProgress == 0 && next == nil {
-				return nil
+				break dispatcherLoop
 			}
-		case linksDepth := <-out:
-			for _, c := range linksDepth.links {
-				cd := &shardCidUnion{
+		case nextNodes := <-out:
+			for _, c := range nextNodes.links {
+				shardOrCid := &shardCidUnion{
 					cid: c,
 				}
 
 				if next == nil {
-					next = cd
+					next = shardOrCid
 					send = feed
 				} else {
-					todoQueue = append(todoQueue, cd)
+					todoQueue = append(todoQueue, shardOrCid)
 				}
 			}
-			for _, shard := range linksDepth.shards {
-				cd := &shardCidUnion{
+			for _, shard := range nextNodes.shards {
+				shardOrCid := &shardCidUnion{
 					shard: shard,
 				}
 
 				if next == nil {
-					next = cd
+					next = shardOrCid
 					send = feed
 				} else {
-					todoQueue = append(todoQueue, cd)
+					todoQueue = append(todoQueue, shardOrCid)
 				}
 			}
-		case err := <-errChan:
-			return err
-
-		case <-ctx.Done():
-			return ctx.Err()
+		case <-errGrpCtx.Done():
+			break dispatcherLoop
 		}
 	}
+	close(feed)
+	return grp.Wait()
 }
 
 func emitResult(ctx context.Context, linkResults chan<- format.LinkResult, r format.LinkResult) {
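To make the worker-pool/dispatcher shape easier to follow in isolation, here is a hedged sketch of the same pattern on a toy tree. The names `item`, `expand`, and `parallelWalk` are hypothetical stand-ins for `shardCidUnion`, the `dserv.Get`/`walkLinks` work, and `parallelWalkDepth`; the sketch omits the visited-CID set and the links/shards split, but keeps the key moves of the change: a fixed number of `errgroup` workers reading from `feed`, a `done` signal per completed item, and termination by breaking out of the dispatcher, closing `feed`, and returning `grp.Wait()`.

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// item is a stand-in for the real work unit (a shard or a CID).
type item struct{ id int }

// expand is a stand-in for the real per-node work (fetching the node and
// walking its links); it returns the children to visit next.
func expand(it *item) ([]*item, error) {
	if it.id >= 7 {
		return nil, nil // leaf
	}
	return []*item{{id: 2*it.id + 1}, {id: 2*it.id + 2}}, nil
}

func parallelWalk(ctx context.Context, root *item) error {
	const concurrency = 4

	grp, errGrpCtx := errgroup.WithContext(ctx)

	feed := make(chan *item)    // dispatcher -> workers
	out := make(chan []*item)   // workers -> dispatcher (discovered children)
	done := make(chan struct{}) // one signal per completed item

	for i := 0; i < concurrency; i++ {
		grp.Go(func() error {
			for it := range feed {
				children, err := expand(it)
				if err != nil {
					return err // first error cancels errGrpCtx
				}
				if len(children) > 0 {
					select {
					case out <- children:
					case <-errGrpCtx.Done():
						return nil
					}
				}
				select {
				case done <- struct{}{}:
				case <-errGrpCtx.Done():
				}
			}
			return nil
		})
	}

	send := feed
	var todo []*item
	next := root
	inProgress := 0

dispatcherLoop:
	for {
		select {
		case send <- next:
			inProgress++
			if len(todo) > 0 {
				next, todo = todo[0], todo[1:]
			} else {
				next, send = nil, nil // nothing queued; park this case
			}
		case <-done:
			inProgress--
			if inProgress == 0 && next == nil {
				break dispatcherLoop // all work finished
			}
		case children := <-out:
			for _, c := range children {
				if next == nil {
					next, send = c, feed
				} else {
					todo = append(todo, c)
				}
			}
		case <-errGrpCtx.Done():
			break dispatcherLoop // a worker failed (or ctx was cancelled)
		}
	}
	close(feed) // lets the workers' range loops exit
	return grp.Wait()
}

func main() {
	fmt.Println(parallelWalk(context.Background(), &item{id: 0}))
}
```

One design point this sketch highlights: because a failing worker now just returns its error, the errgroup context stands in for both the old `errChan` and the manual `fetchersCtx` cancellation, and the dispatcher only needs a single `errGrpCtx.Done()` case before handing the final error back via `grp.Wait()`.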