Skip to content

Commit

Permalink
rapide: add tests and fix race issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorropo committed Feb 2, 2023
1 parent 244cc4b commit 78269ef
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 10 deletions.
28 changes: 19 additions & 9 deletions rapide/rapide.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ func (d *download) finish() {
}

func (d *download) workerFinished() {
// don't decrement d.done because if we suceeded we don't want them to attempt to return an error
d.root.mu.Lock()
defer d.root.mu.Unlock()
if d.root.state == done && len(d.root.childrens) == 0 {
d.finish() // file was downloaded !
var minusOne uint64
minusOne--
if atomic.AddUint64(&d.done, minusOne) == 0 {
d.finish()
}
}

Expand Down Expand Up @@ -121,7 +120,8 @@ type node struct {
state nodeState
}

// expand will run the Traversal and create childrens, it must be called while holding n.mu.Mutex
// expand will run the Traversal and create childrens, it must be called while holding n.mu.Mutex.
// it will unlock n.mu.Mutex
func (n *node) expand(d *download, b blocks.Block) error {
if n.state != todo {
panic(fmt.Sprintf("expanding a node that is not todo: %d", n.state))
Expand All @@ -130,6 +130,7 @@ func (n *node) expand(d *download, b blocks.Block) error {
newResults, err := n.traversal.Traverse(b)
if err != nil {
d.err(err)
n.mu.Unlock()
return err
}

Expand All @@ -147,10 +148,18 @@ func (n *node) expand(d *download, b blocks.Block) error {
}
n.childrens = childrens

for node, parent := n, n.parent; len(node.childrens) == 0; node, parent = parent, parent.parent {
// bubble up node removal
node, parent := n, n.parent
for {
haveChildrens := len(node.childrens) != 0
node.mu.Unlock()

if haveChildrens {
break
}

if parent == nil {
// finished!
d.finish()
return io.EOF
}

Expand All @@ -165,7 +174,8 @@ func (n *node) expand(d *download, b blocks.Block) error {
parent.childrens = append(childrens, nil)[:len(childrens)] // null out for gc
break
}
parent.mu.Unlock()

node, parent = parent, parent.parent
}

return nil
Expand Down
3 changes: 2 additions & 1 deletion rapide/serverdriven.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ func (w *serverDrivenWorker) doOneDownload(ctx context.Context, workCid cid.Cid,
return errGotDoneBlock
}
if err := task.expand(w.download, b); err != nil {
task.mu.Unlock()
return err
}

task.mu.Lock()

Switch:
switch len(task.childrens) {
case 0:
Expand Down
200 changes: 200 additions & 0 deletions rapide/serverdriven_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package rapide_test

import (
"context"
"encoding/binary"
"fmt"
"math"
"testing"
"time"

"github.com/ipfs/go-cid"
"github.com/ipfs/go-libipfs/blocks"
"github.com/ipfs/go-libipfs/ipsl"
"github.com/ipfs/go-libipfs/ipsl/helpers"
. "github.com/ipfs/go-libipfs/rapide"
mh "github.com/multiformats/go-multihash"
)

type mockBlockstore struct {
t *testing.T
delay time.Duration

m map[cid.Cid][]ipsl.CidTraversalPair
}

func (b *mockBlockstore) makeDag(width, depth uint, i *uint64) cid.Cid {
if b.m == nil {
b.m = make(map[cid.Cid][]ipsl.CidTraversalPair)
}

var bytes [8]byte
binary.LittleEndian.PutUint64(bytes[:], *i)
hash, err := mh.Encode(bytes[:], mh.IDENTITY)
if err != nil {
b.t.Fatal(err)
}
*i += 1

var childs []ipsl.CidTraversalPair
if depth == 0 {
childs = []ipsl.CidTraversalPair{}
} else {
childs = make([]ipsl.CidTraversalPair, width)
for idx := range childs {
childs[idx] = ipsl.CidTraversalPair{
Cid: b.makeDag(width, depth-1, i),
Traversal: b,
}
}
}

c := cid.NewCidV1(cid.Raw, hash)
b.m[c] = childs

return c
}

func (bs *mockBlockstore) Traverse(b blocks.Block) ([]ipsl.CidTraversalPair, error) {
c := b.Cid()
childrens, ok := bs.m[c]
if !ok {
bs.t.Fatalf("Traversed not existing cid: %q", c)
}

return childrens, nil
}

func (*mockBlockstore) Serialize() (ipsl.AstNode, []ipsl.BoundScope, error) {
panic("MOCK!")
}

func (*mockBlockstore) SerializeForNetwork() (ipsl.AstNode, []ipsl.BoundScope, error) {
panic("MOCK!")
}

func (bs *mockBlockstore) Download(ctx context.Context, root cid.Cid, traversal ipsl.Traversal) (ClosableBlockIterator, error) {
ctx, cancel := context.WithCancel(ctx)
r := make(chan blocks.BlockOrError)

go func() {
defer close(r)
helpers.SyncDFS(ctx, root, traversal, bs, math.MaxUint, func(b blocks.Block) error {
select {
case r <- blocks.Is(b):
return nil
case <-ctx.Done():
return ctx.Err()
}
})
}()

return download{r, cancel, ctx}, nil
}

func (bs *mockBlockstore) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

time.Sleep(bs.delay)

h := c.Hash()[1:] // skip 0x00 prefix
_, n := binary.Uvarint(h)
h = h[n:]
return blocks.NewBlockWithCid(h, c)
}

func (bs *mockBlockstore) GetBlocks(ctx context.Context, ks []cid.Cid) <-chan blocks.Block {
r := make(chan blocks.Block, len(ks))
for _, c := range ks {
b, err := bs.GetBlock(ctx, c)
if err != nil {
break
}

r <- b
}

return r
}

func (*mockBlockstore) String() string {
return "mock"
}

type download struct {
c <-chan blocks.BlockOrError
cancel context.CancelFunc
ctx context.Context
}

func (d download) Next() (blocks.Block, error) {
select {
case <-d.ctx.Done():
return nil, d.ctx.Err()
case v := <-d.c:
return v.Get()
}
}

func (d download) Close() error {
d.cancel()
return nil
}

func TestServerDrivenDownloader(t *testing.T) {
for _, tc := range [...]struct {
delay time.Duration
runners uint
width uint
depth uint
}{
{0, 1, 2, 2},
{0, 10, 5, 5},
{0, 100, 3, 10},
{time.Nanosecond, 1, 2, 2},
{time.Nanosecond, 10, 5, 5},
{time.Nanosecond, 100, 3, 10},
{time.Microsecond, 1, 2, 2},
{time.Microsecond, 10, 5, 5},
{time.Microsecond, 100, 3, 10},
{time.Millisecond, 1, 2, 2},
{time.Millisecond, 10, 5, 5},
{time.Millisecond, 100, 3, 10},
} {
t.Run(fmt.Sprintf("%v %v %v %v", tc.delay, tc.runners, tc.width, tc.depth), func(t *testing.T) {
bs := &mockBlockstore{
t: t,
delay: tc.delay,
}
var i uint64
root := bs.makeDag(tc.width, tc.depth, &i)

clients := make([]ServerDrivenDownloader, tc.runners)
for i := tc.runners; i != 0; {
i--
clients[i] = bs
}

seen := make(map[cid.Cid]struct{})
for b := range (&Client{ServerDrivenDownloaders: clients}).Get(context.Background(), root, bs) {
block, err := b.Get()
if err != nil {
t.Fatalf("got error from rapide: %s", err)
}
c := block.Cid()
if _, ok := bs.m[c]; !ok {
t.Fatalf("got cid not in blockstore %s", c)
}
seen[c] = struct{}{}
}

if len(seen) != len(bs.m) {
t.Fatalf("seen less blocks than in blockstore: expected %d; got %d", len(bs.m), len(seen))
}
})
}
}

0 comments on commit 78269ef

Please sign in to comment.