Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
implement a Transport.Close that waits for the reuse's GC to finish
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jul 7, 2021
1 parent 41666cb commit 21966cc
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 39 deletions.
20 changes: 20 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
Expand Down Expand Up @@ -70,11 +71,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv4", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -94,11 +97,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv6", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip6/::1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -118,11 +123,13 @@ var _ = Describe("Connection", func() {
It("opens and accepts streams", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -147,13 +154,15 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
// dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
Expect(err).To(HaveOccurred())
defer clientTransport.(io.Closer).Close()
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))

done := make(chan struct{})
Expand All @@ -172,6 +181,7 @@ var _ = Describe("Connection", func() {
cg.EXPECT().InterceptAccept(gomock.Any())
serverTransport, err := NewTransport(serverKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

Expand All @@ -185,6 +195,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
// make sure that connection attempts fails
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -205,6 +216,7 @@ var _ = Describe("Connection", func() {
It("gates secured connections", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

Expand All @@ -213,6 +225,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()

// make sure that connection attempts fails
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -232,10 +245,12 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln1.Close()
serverTransport2, err := NewTransport(serverKey2, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport2.(io.Closer).Close()
ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
defer ln2.Close()

Expand All @@ -262,6 +277,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer c1.Close()
Expand Down Expand Up @@ -291,6 +307,7 @@ var _ = Describe("Connection", func() {
It("sends stateless resets", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

var drop uint32
Expand All @@ -307,6 +324,7 @@ var _ = Describe("Connection", func() {
// establish a connection
clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr())
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
Expand Down Expand Up @@ -349,6 +367,7 @@ var _ = Describe("Connection", func() {
It("hole punches", func() {
t1, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer t1.(io.Closer).Close()
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred())
ln1, err := t1.Listen(laddr)
Expand All @@ -364,6 +383,7 @@ var _ = Describe("Connection", func() {

t2, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer t2.(io.Closer).Close()
ln2, err := t2.Listen(laddr)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
Expand Down
13 changes: 1 addition & 12 deletions libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package libp2pquic

import (
"bytes"
mrand "math/rand"
"runtime/pprof"
"strings"
"testing"
"time"

gomock "github.com/golang/mock/gomock"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"

. "github.com/onsi/ginkgo"
Expand All @@ -31,16 +28,9 @@ var (
mockCtrl *gomock.Controller
)

func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}

var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())

Expect(isGarbageCollectorRunning()).To(BeFalse())
garbageCollectIntervalOrig = garbageCollectInterval
maxUnusedDurationOrig = maxUnusedDuration
garbageCollectInterval = 50 * time.Millisecond
Expand All @@ -52,7 +42,6 @@ var _ = BeforeEach(func() {
var _ = AfterEach(func() {
mockCtrl.Finish()

Eventually(isGarbageCollectorRunning).Should(BeFalse())
garbageCollectInterval = garbageCollectIntervalOrig
maxUnusedDuration = maxUnusedDurationOrig
quicConfig = origQuicConfig
Expand Down
7 changes: 6 additions & 1 deletion listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ import (
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"syscall"

ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go"

ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo"
Expand All @@ -38,6 +39,10 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
Expect(t.(io.Closer).Close()).To(Succeed())
})

It("uses a conn that can interface assert to a UDPConn for listening", func() {
origQuicListen := quicListen
defer func() { quicListen = origQuicListen }()
Expand Down
72 changes: 47 additions & 25 deletions reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,52 +53,62 @@ type reuse struct {

garbageCollectorRunning bool

closeChan chan struct{}
garbageCollectorStopChan chan struct{}

unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn
}

func newReuse() *reuse {
return &reuse{
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
closeChan: make(chan struct{}),
}
}

func (r *reuse) runGarbageCollector() {
defer close(r.garbageCollectorStopChan)
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()

for now := range ticker.C {
var shouldExit bool
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(r.global, key)
}
}
for ukey, conns := range r.unicast {
for key, conn := range conns {
for {
select {
case <-r.closeChan:
return
case now := <-ticker.C:
var shouldExit bool
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
delete(r.global, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
for ukey, conns := range r.unicast {
for key, conn := range conns {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
}
}
}

// stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false
shouldExit = true
}
r.mutex.Unlock()
// stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false
shouldExit = true
}
r.mutex.Unlock()

if shouldExit {
return
if shouldExit {
return
}
}
}
}
Expand All @@ -107,6 +117,7 @@ func (r *reuse) runGarbageCollector() {
func (r *reuse) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true
r.garbageCollectorStopChan = make(chan struct{})
go r.runGarbageCollector()
}
}
Expand Down Expand Up @@ -199,3 +210,14 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
r.unicast[localAddr.IP.String()][localAddr.Port] = rconn
return rconn, err
}

func (r *reuse) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
close(r.closeChan)
if r.garbageCollectorRunning {
<-r.garbageCollectorStopChan
r.garbageCollectorRunning = false
}
return nil
}
14 changes: 13 additions & 1 deletion reuse_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package libp2pquic

import (
"bytes"
"net"
"runtime/pprof"
"strings"
"time"

"github.com/libp2p/go-netroute"
Expand Down Expand Up @@ -30,7 +33,6 @@ func closeAllConns(reuse *reuse) {
}
}
reuse.mutex.Unlock()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
}

func OnPlatformsWithRoutingTablesIt(description string, f interface{}) {
Expand All @@ -48,6 +50,16 @@ var _ = Describe("Reuse", func() {
reuse = newReuse()
})

AfterEach(func() {
Expect(reuse.Close()).To(Succeed())
})

isGarbageCollectorRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}

Context("creating and reusing connections", func() {
AfterEach(func() { closeAllConns(reuse) })

Expand Down
11 changes: 11 additions & 0 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ func (c *connManager) Dial(network string, raddr *net.UDPAddr) (*reuseConn, erro
return reuse.Dial(network, raddr)
}

func (c *connManager) Close() error {
if err := c.reuseUDP6.Close(); err != nil {
return err
}
return c.reuseUDP4.Close()
}

// The Transport implements the tpt.Transport interface for QUIC connections.
type transport struct {
privKey ic.PrivKey
Expand Down Expand Up @@ -349,3 +356,7 @@ func (t *transport) Protocols() []int {
func (t *transport) String() string {
return "QUIC"
}

func (t *transport) Close() error {
return t.connManager.Close()
}
Loading

0 comments on commit 21966cc

Please sign in to comment.