Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

implement a Transport.Close that waits for the reuse's GC to finish #211

Merged
merged 4 commits into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/unit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ jobs:
run: go install
- name: Run tests
run: ginkgo -r -v --cover -coverprofile coverage.txt --randomizeAllSpecs --randomizeSuites --trace --progress
- name: Run tests with race detector
if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow
run: ginkgo -r -v -race --randomizeAllSpecs --randomizeSuites --trace --progress
- name: Run tests (32 bit)
if: ${{ matrix.os != 'macos' }} # can't run 32 bit tests on OSX.
env:
Expand Down
20 changes: 20 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
Expand Down Expand Up @@ -70,11 +71,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv4", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -94,11 +97,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv6", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip6/::1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -118,11 +123,13 @@ var _ = Describe("Connection", func() {
It("opens and accepts streams", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
Expand All @@ -147,13 +154,15 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
// dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
Expect(err).To(HaveOccurred())
defer clientTransport.(io.Closer).Close()
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))

done := make(chan struct{})
Expand All @@ -172,6 +181,7 @@ var _ = Describe("Connection", func() {
cg.EXPECT().InterceptAccept(gomock.Any())
serverTransport, err := NewTransport(serverKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

Expand All @@ -185,6 +195,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
// make sure that connection attempts fails
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -205,6 +216,7 @@ var _ = Describe("Connection", func() {
It("gates secured connections", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

Expand All @@ -213,6 +225,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, cg)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()

// make sure that connection attempts fails
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -232,10 +245,12 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln1.Close()
serverTransport2, err := NewTransport(serverKey2, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport2.(io.Closer).Close()
ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
defer ln2.Close()

Expand All @@ -262,6 +277,7 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer c1.Close()
Expand Down Expand Up @@ -291,6 +307,7 @@ var _ = Describe("Connection", func() {
It("sends stateless resets", func() {
serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

var drop uint32
Expand All @@ -307,6 +324,7 @@ var _ = Describe("Connection", func() {
// establish a connection
clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr())
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
Expand Down Expand Up @@ -349,6 +367,7 @@ var _ = Describe("Connection", func() {
It("hole punches", func() {
t1, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer t1.(io.Closer).Close()
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred())
ln1, err := t1.Listen(laddr)
Expand All @@ -364,6 +383,7 @@ var _ = Describe("Connection", func() {

t2, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
defer t2.(io.Closer).Close()
ln2, err := t2.Listen(laddr)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
Expand Down
13 changes: 1 addition & 12 deletions libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package libp2pquic

import (
"bytes"
mrand "math/rand"
"runtime/pprof"
"strings"
"testing"
"time"

gomock "github.com/golang/mock/gomock"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"

. "github.com/onsi/ginkgo"
Expand All @@ -31,16 +28,9 @@ var (
mockCtrl *gomock.Controller
)

func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}

var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())

Expect(isGarbageCollectorRunning()).To(BeFalse())
garbageCollectIntervalOrig = garbageCollectInterval
maxUnusedDurationOrig = maxUnusedDuration
garbageCollectInterval = 50 * time.Millisecond
Expand All @@ -52,7 +42,6 @@ var _ = BeforeEach(func() {
var _ = AfterEach(func() {
mockCtrl.Finish()

Eventually(isGarbageCollectorRunning).Should(BeFalse())
garbageCollectInterval = garbageCollectIntervalOrig
maxUnusedDuration = maxUnusedDurationOrig
quicConfig = origQuicConfig
Expand Down
7 changes: 6 additions & 1 deletion listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ import (
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"syscall"

ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go"

ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo"
Expand All @@ -38,6 +39,10 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
Expect(t.(io.Closer).Close()).To(Succeed())
})

It("uses a conn that can interface assert to a UDPConn for listening", func() {
origQuicListen := quicListen
defer func() { quicListen = origQuicListen }()
Expand Down
89 changes: 53 additions & 36 deletions reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"sync"
"time"

"github.com/libp2p/go-libp2p-core/connmgr"

"github.com/libp2p/go-netroute"
)

Expand All @@ -24,7 +22,7 @@ type reuseConn struct {
unusedSince time.Time
}

func newReuseConn(conn *net.UDPConn, gater connmgr.ConnectionGater) *reuseConn {
func newReuseConn(conn *net.UDPConn) *reuseConn {
return &reuseConn{UDPConn: conn}
}

Expand Down Expand Up @@ -53,57 +51,64 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
type reuse struct {
mutex sync.Mutex

gater connmgr.ConnectionGater

garbageCollectorRunning bool

closeChan chan struct{}
garbageCollectorStopChan chan struct{}

unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn
}

func newReuse(gater connmgr.ConnectionGater) *reuse {
func newReuse() *reuse {
return &reuse{
gater: gater,
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
closeChan: make(chan struct{}),
}
}

func (r *reuse) runGarbageCollector() {
defer close(r.garbageCollectorStopChan)
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()

for now := range ticker.C {
var shouldExit bool
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(r.global, key)
}
}
for ukey, conns := range r.unicast {
for key, conn := range conns {
for {
select {
case <-r.closeChan:
return
case now := <-ticker.C:
var shouldExit bool
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
delete(r.global, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
for ukey, conns := range r.unicast {
for key, conn := range conns {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
}
}
}

// stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false
shouldExit = true
}
r.mutex.Unlock()
// stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false
shouldExit = true
}
r.mutex.Unlock()

if shouldExit {
return
if shouldExit {
return
}
}
}
}
Expand All @@ -112,6 +117,7 @@ func (r *reuse) runGarbageCollector() {
func (r *reuse) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true
r.garbageCollectorStopChan = make(chan struct{})
go r.runGarbageCollector()
}
}
Expand All @@ -127,7 +133,7 @@ func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, ip)
conn, err := r.dialLocked(network, ip)
if err != nil {
return nil, err
}
Expand All @@ -136,7 +142,7 @@ func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
return conn, nil
}

func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, source *net.IP) (*reuseConn, error) {
func (r *reuse) dialLocked(network string, source *net.IP) (*reuseConn, error) {
if source != nil {
// We already have at least one suitable connection...
if conns, ok := r.unicast[source.String()]; ok {
Expand Down Expand Up @@ -166,7 +172,7 @@ func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, source *net.IP) (
if err != nil {
return nil, err
}
rconn := newReuseConn(conn, r.gater)
rconn := newReuseConn(conn)
r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn
return rconn, nil
}
Expand All @@ -178,7 +184,7 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
}
localAddr := conn.LocalAddr().(*net.UDPAddr)

rconn := newReuseConn(conn, r.gater)
rconn := newReuseConn(conn)
rconn.IncreaseCount()

r.mutex.Lock()
Expand All @@ -204,3 +210,14 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
r.unicast[localAddr.IP.String()][localAddr.Port] = rconn
return rconn, err
}

func (r *reuse) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
close(r.closeChan)
if r.garbageCollectorRunning {
<-r.garbageCollectorStopChan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can deadlock if we're waiting for the garbage collection routine to take the lock.

r.garbageCollectorRunning = false
}
return nil
}
Loading