Skip to content
This repository has been archived by the owner on Aug 2, 2021. It is now read-only.

retrieval: fix memory leak #2103

Merged
merged 4 commits into from
Feb 19, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions network/retrieval/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ func (p *Peer) addRetrieval(ruid uint, addr storage.Address) {
p.retrievals[ruid] = addr
}

func (p *Peer) expireRetrieval(ruid uint) {
p.mtx.Lock()
defer p.mtx.Unlock()

delete(p.retrievals, ruid)
}

// chunkReceived is called upon ChunkDelivery message reception
// it is meant to idenfify unsolicited chunk deliveries
func (p *Peer) checkRequest(ruid uint, addr storage.Address) error {
Expand Down
14 changes: 9 additions & 5 deletions network/retrieval/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ func (r *Retrieval) handleChunkDelivery(ctx context.Context, p *Peer, msg *Chunk
}

// RequestFromPeers sends a chunk retrieve request to the next found peer
func (r *Retrieval) RequestFromPeers(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, error) {
func (r *Retrieval) RequestFromPeers(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, func(), error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to have a short explanation what is the purpose of the returned function, as now, it requires to go through the code. Or to name it in function signature.

r.logger.Debug("retrieval.requestFromPeers", "req.Addr", req.Addr, "localID", localID)
metrics.GetOrRegisterCounter("network.retrieve.request_from_peers", nil).Inc(1)

Expand All @@ -395,7 +395,7 @@ FINDPEER:
sp, err := r.findPeerLB(ctx, req)
if err != nil {
r.logger.Trace(err.Error())
return nil, err
return nil, func() {}, err
}

protoPeer := r.getPeer(sp.ID())
Expand All @@ -405,7 +405,7 @@ FINDPEER:
retries++
if retries == maxFindPeerRetries {
r.logger.Error("max find peer retries reached", "max retries", maxFindPeerRetries, "ref", req.Addr)
return nil, ErrNoPeerFound
return nil, func() {}, ErrNoPeerFound
}

goto FINDPEER
Expand All @@ -417,14 +417,18 @@ FINDPEER:
}
protoPeer.logger.Trace("sending retrieve request", "ref", ret.Addr, "origin", localID, "ruid", ret.Ruid)
protoPeer.addRetrieval(ret.Ruid, ret.Addr)
cleanup := func() {
protoPeer.expireRetrieval(ret.Ruid)
}
err = protoPeer.Send(ctx, ret)
if err != nil {
protoPeer.logger.Error("error sending retrieve request to peer", "ruid", ret.Ruid, "err", err)
return nil, err
cleanup()
return nil, func() {}, err
}

spID := protoPeer.ID()
return &spID, nil
return &spID, cleanup, nil
}

func (r *Retrieval) Start(server *p2p.Server) error {
Expand Down
8 changes: 4 additions & 4 deletions network/retrieval/retrieve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ func TestUnsolicitedChunkDeliveryFaultyAddr(t *testing.T) {
t.Fatal(err)
}
defer teardown()
ns.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, error) {
return &enode.ID{}, nil
ns.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, func(), error) {
return &enode.ID{}, func() {}, nil
}
node := tester.Nodes[0]

Expand Down Expand Up @@ -267,8 +267,8 @@ func TestUnsolicitedChunkDeliveryDouble(t *testing.T) {
t.Fatal(err)
}
defer teardown()
ns.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, error) {
return &enode.ID{}, nil
ns.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, func(), error) {
return &enode.ID{}, func() {}, nil
}
node := tester.Nodes[0]

Expand Down
8 changes: 4 additions & 4 deletions storage/feed/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ func NewTestHandler(datadir string, params *HandlerParams) (*TestHandler, error)
localStore := chunk.NewValidatorStore(db, storage.NewContentAddressValidator(storage.MakeHashFunc(feedsHashAlgorithm)), fh)

netStore := storage.NewNetStore(localStore, network.NewBzzAddr(make([]byte, 32), nil))
netStore.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, error) {
return nil, errors.New("not found")
netStore.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, func(), error) {
return nil, func() {}, errors.New("not found")
}
fh.SetStore(netStore)
return &TestHandler{fh}, nil
Expand All @@ -69,8 +69,8 @@ func newTestHandlerWithStore(fh *Handler, datadir string, db chunk.Store, params
localStore := chunk.NewValidatorStore(db, storage.NewContentAddressValidator(storage.MakeHashFunc(feedsHashAlgorithm)), fh)

netStore := storage.NewNetStore(localStore, network.NewBzzAddr(make([]byte, 32), nil))
netStore.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, error) {
return nil, errors.New("not found")
netStore.RemoteGet = func(ctx context.Context, req *storage.Request, localID enode.ID) (*enode.ID, func(), error) {
return nil, func() {}, errors.New("not found")
}
fh.SetStore(netStore)
return &TestHandler{fh}, nil
Expand Down
5 changes: 3 additions & 2 deletions storage/netstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (fi *Fetcher) SafeClose(ch chunk.Chunk) {
})
}

type RemoteGetFunc func(ctx context.Context, req *Request, localID enode.ID) (*enode.ID, error)
type RemoteGetFunc func(ctx context.Context, req *Request, localID enode.ID) (*enode.ID, func(), error)

// NetStore is an extension of LocalStore
// it implements the ChunkStore interface
Expand Down Expand Up @@ -247,13 +247,14 @@ func (n *NetStore) RemoteFetch(ctx context.Context, req *Request, fi *Fetcher) (

log.Trace("remote.fetch", "ref", ref)

currentPeer, err := n.RemoteGet(ctx, req, n.LocalID)
currentPeer, cleanup, err := n.RemoteGet(ctx, req, n.LocalID)
if err != nil {
n.logger.Trace(err.Error(), "ref", ref)
osp.LogFields(olog.String("err", err.Error()))
osp.Finish()
return nil, ErrNoSuitablePeer
}
defer cleanup()

// add peer to the set of peers to skip from now
n.logger.Trace("remote.fetch, adding peer to skip", "ref", ref, "peer", currentPeer.String())
Expand Down