Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

implement a Transport.Close that waits for the reuse's GC to finish #211

Merged
merged 4 commits into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/unit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ jobs:
run: go install
- name: Run tests
run: ginkgo -r -v --cover -coverprofile coverage.txt --randomizeAllSpecs --randomizeSuites --trace --progress
- name: Run tests with race detector
if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow
run: ginkgo -r -v -race --randomizeAllSpecs --randomizeSuites --trace --progress
- name: Run tests (32 bit)
if: ${{ matrix.os != 'macos' }} # can't run 32 bit tests on OSX.
env:
Expand Down
20 changes: 20 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
Expand Down Expand Up @@ -70,11 +71,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv4", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -94,11 +97,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv6", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip6/::1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -118,11 +123,13 @@ var _ = Describe("Connection", func() {
It("opens and accepts streams", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -147,13 +154,15 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
// dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
Expect(err).To(HaveOccurred())
defer clientTransport.(io.Closer).Close()
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))

done := make(chan struct{})
Expand All @@ -172,6 +181,7 @@ var _ = Describe("Connection", func() {
cg.EXPECT().InterceptAccept(gomock.Any())
serverTransport, err := NewTransport(serverKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

Expand All @@ -185,6 +195,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
// make sure that connection attempts fails
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -205,6 +216,7 @@ var _ = Describe("Connection", func() {
It("gates secured connections", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

Expand All @@ -213,6 +225,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()

// make sure that connection attempts fails
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -232,10 +245,12 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln1.Close()
serverTransport2, err := NewTransport(serverKey2, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport2.(io.Closer).Close()
ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
defer ln2.Close()

Expand All @@ -262,6 +277,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer c1.Close()
Expand Down Expand Up @@ -291,6 +307,7 @@ var _ = Describe("Connection", func() {
It("sends stateless resets", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

var drop uint32
Expand All @@ -307,6 +324,7 @@ var _ = Describe("Connection", func() {
// establish a connection
clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr())
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
Expand Down Expand Up @@ -349,6 +367,7 @@ var _ = Describe("Connection", func() {
It("hole punches", func() {
t1, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer t1.(io.Closer).Close()
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred())
ln1, err := t1.Listen(laddr)
Expand All @@ -364,6 +383,7 @@ var _ = Describe("Connection", func() {

t2, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer t2.(io.Closer).Close()
ln2, err := t2.Listen(laddr)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
Expand Down
13 changes: 1 addition & 12 deletions libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package libp2pquic

import (
"bytes"
mrand "math/rand"
"runtime/pprof"
"strings"
"testing"
"time"

gomock "github.com/golang/mock/gomock"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"

. "github.com/onsi/ginkgo"
Expand All @@ -31,16 +28,9 @@ var (
mockCtrl *gomock.Controller
)

func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}

var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())

Expect(isGarbageCollectorRunning()).To(BeFalse())
garbageCollectIntervalOrig = garbageCollectInterval
maxUnusedDurationOrig = maxUnusedDuration
garbageCollectInterval = 50 * time.Millisecond
Expand All @@ -52,7 +42,6 @@ var _ = BeforeEach(func() {
var _ = AfterEach(func() {
mockCtrl.Finish()

Eventually(isGarbageCollectorRunning).Should(BeFalse())
garbageCollectInterval = garbageCollectIntervalOrig
maxUnusedDuration = maxUnusedDurationOrig
quicConfig = origQuicConfig
Expand Down
7 changes: 6 additions & 1 deletion listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ import (
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"syscall"

ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go"

ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo"
Expand All @@ -38,6 +39,10 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
Expect(t.(io.Closer).Close()).To(Succeed())
})

It("uses a conn that can interface assert to a UDPConn for listening", func() {
origQuicListen := quicListen
defer func() { quicListen = origQuicListen }()
Expand Down
93 changes: 42 additions & 51 deletions reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"sync"
"time"

"github.com/libp2p/go-libp2p-core/connmgr"

"github.com/libp2p/go-netroute"
)

Expand All @@ -24,7 +22,7 @@ type reuseConn struct {
unusedSince time.Time
}

func newReuseConn(conn *net.UDPConn, gater connmgr.ConnectionGater) *reuseConn {
func newReuseConn(conn *net.UDPConn) *reuseConn {
return &reuseConn{UDPConn: conn}
}

Expand Down Expand Up @@ -53,68 +51,58 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
type reuse struct {
mutex sync.Mutex

gater connmgr.ConnectionGater

garbageCollectorRunning bool
closeChan chan struct{}
gcStopChan chan struct{}

unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn
}

func newReuse(gater connmgr.ConnectionGater) *reuse {
return &reuse{
gater: gater,
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
func newReuse() *reuse {
r := &reuse{
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
closeChan: make(chan struct{}),
gcStopChan: make(chan struct{}),
}
go r.gc()
return r
}

func (r *reuse) runGarbageCollector() {
func (r *reuse) gc() {
defer close(r.gcStopChan)
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()

for now := range ticker.C {
var shouldExit bool
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(r.global, key)
}
}
for ukey, conns := range r.unicast {
for key, conn := range conns {
for {
select {
case <-r.closeChan:
return
case now := <-ticker.C:
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
delete(r.global, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
for ukey, conns := range r.unicast {
for key, conn := range conns {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
}
}
}

// stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false
shouldExit = true
}
r.mutex.Unlock()

if shouldExit {
return
r.mutex.Unlock()
}
}
}

// must be called while holding the mutex
func (r *reuse) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true
go r.runGarbageCollector()
}
}
func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
var ip *net.IP
if router, err := netroute.New(); err == nil {
Expand All @@ -127,16 +115,15 @@ func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, ip)
conn, err := r.dialLocked(network, ip)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}

func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, source *net.IP) (*reuseConn, error) {
func (r *reuse) dialLocked(network string, source *net.IP) (*reuseConn, error) {
if source != nil {
// We already have at least one suitable connection...
if conns, ok := r.unicast[source.String()]; ok {
Expand Down Expand Up @@ -166,7 +153,7 @@ func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, source *net.IP) (
if err != nil {
return nil, err
}
rconn := newReuseConn(conn, r.gater)
rconn := newReuseConn(conn)
r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn
return rconn, nil
}
Expand All @@ -178,14 +165,12 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
}
localAddr := conn.LocalAddr().(*net.UDPAddr)

rconn := newReuseConn(conn, r.gater)
rconn := newReuseConn(conn)
rconn.IncreaseCount()

r.mutex.Lock()
defer r.mutex.Unlock()

r.maybeStartGarbageCollector()

// Deal with listen on a global address
if localAddr.IP.IsUnspecified() {
// The kernel already checked that the laddr is not already listen
Expand All @@ -204,3 +189,9 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
r.unicast[localAddr.IP.String()][localAddr.Port] = rconn
return rconn, err
}

func (r *reuse) Close() error {
close(r.closeChan)
<-r.gcStopChan
return nil
}
Loading