fix(common.socket): Switch to context to simplify closing (#15589)
(cherry picked from commit 3d9562b)
srebhan authored and powersj committed Jul 22, 2024
1 parent a590193 commit e758275
Showing 3 changed files with 289 additions and 54 deletions.
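
The listener rework in plugins/common/socket/socket.go is mostly collapsed in this view, so as a rough illustration of the context-based shutdown the title describes: instead of tracking every open connection and closing each one by hand, the listener cancels a single context that all per-connection goroutines watch. A minimal sketch of that pattern (all names hypothetical, not the actual streamListener code):

```go
package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// listener is a sketch only: one CancelFunc replaces a map of open
// connections that Close() would otherwise have to walk and drain.
type listener struct {
	ln     net.Listener
	cancel context.CancelFunc
}

func (l *listener) serve(ctx context.Context) {
	for {
		conn, err := l.ln.Accept()
		if err != nil {
			return // Accept fails once l.ln is closed
		}
		go func() {
			defer conn.Close()
			done := make(chan struct{})
			go func() {
				defer close(done)
				buf := make([]byte, 1024)
				for {
					if _, err := conn.Read(buf); err != nil {
						return
					}
				}
			}()
			select {
			case <-ctx.Done(): // one cancellation stops every handler
			case <-done:
			}
		}()
	}
}

func (l *listener) Close() {
	l.cancel()   // unblock all handlers at once...
	l.ln.Close() // ...and stop accepting new connections
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	l := &listener{ln: ln, cancel: cancel}
	go l.serve(ctx)

	c, _ := net.Dial("tcp", ln.Addr().String())
	fmt.Fprintln(c, "test value=42i")
	time.Sleep(50 * time.Millisecond)
	l.Close() // no per-connection bookkeeping needed to shut down
}
```
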
2 changes: 1 addition & 1 deletion plugins/common/socket/socket.go
@@ -27,7 +27,7 @@ type listener interface {
}

type Config struct {
-	MaxConnections  int              `toml:"max_connections"`
+	MaxConnections  uint64           `toml:"max_connections"`
	ReadBufferSize  config.Size      `toml:"read_buffer_size"`
	ReadTimeout     config.Duration  `toml:"read_timeout"`
	KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
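
For reference, the retyped option surfaces in socket-based plugin configs along these lines (a sketch; the plugin table and defaults are assumptions, not part of this diff). Parsing it as an unsigned integer presumably lets negative values be rejected at config load:

```toml
[[inputs.socket_listener]]
  service_address = "tcp://127.0.0.1:8094"
  ## Maximum number of concurrent connections, now an unsigned integer.
  ## Presumably 0 keeps connections unlimited, as before.
  max_connections = 5
  read_buffer_size = "64KiB"
  read_timeout = "5s"
  keep_alive_period = "10ms"
```
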
223 changes: 222 additions & 1 deletion plugins/common/socket/socket_test.go
@@ -2,12 +2,14 @@ package socket

import (
	"crypto/tls"
+	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"strings"
+	"sync"
	"testing"
	"time"

@@ -481,7 +483,7 @@ func TestClosingConnections(t *testing.T) {
	listener, ok := sock.listener.(*streamListener)
	require.True(t, ok)
	listener.Lock()
-	conns := len(listener.connections)
+	conns := listener.connections
	listener.Unlock()
	require.NotZero(t, conns)

@@ -496,6 +498,130 @@ func TestClosingConnections(t *testing.T) {
	require.Empty(t, logger.Errors())
	require.Empty(t, logger.Warnings())
}

func TestMaxConnections(t *testing.T) {
	if runtime.GOOS == "darwin" {
		t.Skip("Skipping on darwin due to missing socket options")
	}

	// Setup the configuration
	period := config.Duration(10 * time.Millisecond)
	cfg := &Config{
		MaxConnections:  5,
		KeepAlivePeriod: &period,
	}

	// Create the socket
	serviceAddress := "tcp://127.0.0.1:0"
	sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
	require.NoError(t, err)

	// Create callback
	var errs []error
	var mu sync.Mutex
	onData := func(_ net.Addr, _ []byte) {}
	onError := func(err error) {
		mu.Lock()
		errs = append(errs, err)
		mu.Unlock()
	}

	// Start the listener
	require.NoError(t, sock.Setup())
	sock.Listen(onData, onError)
	defer sock.Close()

	addr := sock.Address()

	// Create maximum number of connections and write some data. All of this
	// should succeed...
	clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
	for i := 0; i < int(cfg.MaxConnections); i++ {
		c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
		require.NoError(t, err)
		require.NoError(t, c.SetWriteBuffer(0))
		require.NoError(t, c.SetNoDelay(true))
		clients = append(clients, c)

		_, err = c.Write([]byte("test value=42i\n"))
		require.NoError(t, err)
	}

	func() {
		mu.Lock()
		defer mu.Unlock()
		require.Empty(t, errs)
	}()

	// Create another client. This should fail because we already reached the
	// connection limit and the connection should be closed...
	client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
	require.NoError(t, err)
	require.NoError(t, client.SetWriteBuffer(0))
	require.NoError(t, client.SetNoDelay(true))

	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return len(errs) > 0
	}, 3*time.Second, 100*time.Millisecond)
	func() {
		mu.Lock()
		defer mu.Unlock()
		require.Len(t, errs, 1)
		require.ErrorContains(t, errs[0], "too many connections")
		errs = make([]error, 0)
	}()

	require.Eventually(t, func() bool {
		_, err := client.Write([]byte("fail\n"))
		return err != nil
	}, 3*time.Second, 100*time.Millisecond)
	_, err = client.Write([]byte("test\n"))
	require.Error(t, err)

	// Check other connections are still good
	for _, c := range clients {
		_, err := c.Write([]byte("test\n"))
		require.NoError(t, err)
	}
	func() {
		mu.Lock()
		defer mu.Unlock()
		require.Empty(t, errs)
	}()

	// Close the first client and check if we can connect now
	require.NoError(t, clients[0].Close())
	client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
	require.NoError(t, err)
	require.NoError(t, client.SetWriteBuffer(0))
	require.NoError(t, client.SetNoDelay(true))
	_, err = client.Write([]byte("success\n"))
	require.NoError(t, err)

	// Close all connections
	require.NoError(t, client.Close())
	for _, c := range clients[1:] {
		require.NoError(t, c.Close())
	}

	// Close the clients and check the connection counter
	listener, ok := sock.listener.(*streamListener)
	require.True(t, ok)
	require.Eventually(t, func() bool {
		listener.Lock()
		conns := listener.connections
		listener.Unlock()
		return conns == 0
	}, 3*time.Second, 100*time.Millisecond)

	// Close the socket and check again...
	sock.Close()
	listener.Lock()
	conns := listener.connections
	listener.Unlock()
	require.Zero(t, conns)
}
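
The test reads listener.connections back as a bare number (the old code took len() of it), which together with the uint64 config field suggests the accept loop enforces the limit with a plain mutex-guarded counter. A sketch of that shape (illustrative only, not the actual streamListener implementation; only the "too many connections" substring is confirmed by the test above):

```go
package main

import (
	"fmt"
	"net"
	"sync"
	"time"
)

// streamListener sketch: connections is a plain uint64 guarded by the
// mutex, which is why the tests read it directly instead of len()-ing
// a map of tracked connections.
type streamListener struct {
	sync.Mutex
	MaxConnections uint64
	connections    uint64
}

func (l *streamListener) accept(ln net.Listener, onError func(error)) {
	for {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		l.Lock()
		if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
			l.Unlock()
			conn.Close() // reject: the client's next writes fail
			onError(fmt.Errorf("unable to accept connection: too many connections"))
			continue
		}
		l.connections++
		l.Unlock()
		go func() {
			defer func() {
				l.Lock()
				l.connections--
				l.Unlock()
				conn.Close()
			}()
			buf := make([]byte, 1024)
			for {
				if _, err := conn.Read(buf); err != nil {
					return
				}
			}
		}()
	}
}

func main() {
	ln, _ := net.Listen("tcp", "127.0.0.1:0")
	l := &streamListener{MaxConnections: 5}
	go l.accept(ln, func(err error) { fmt.Println("listener error:", err) })

	// The sixth dial is still accepted by the TCP stack but immediately
	// closed by the server once the counter is at its limit.
	for i := 0; i < 6; i++ {
		if c, err := net.Dial("tcp", ln.Addr().String()); err == nil {
			defer c.Close()
		}
	}
	time.Sleep(100 * time.Millisecond) // let the accept loop log the rejection
}
```
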

func TestNoSplitter(t *testing.T) {
	messages := [][]byte{
@@ -605,6 +731,101 @@ func TestNoSplitter(t *testing.T) {
	testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
}

func TestTLSMemLeak(t *testing.T) {
	// For issue https://github.com/influxdata/telegraf/issues/15509

	// Prepare the address and socket if needed
	serviceAddress := "tcp://127.0.0.1:0"

	// Setup a TLS socket to trigger the issue
	cfg := &Config{
		ServerConfig: *pki.TLSServerConfig(),
	}

	// Create the socket
	sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
	require.NoError(t, err)

	// Create callbacks
	onConnection := func(_ net.Addr, reader io.ReadCloser) {
		//nolint:errcheck // We are not interested in the data so ignore all errors
		io.Copy(io.Discard, reader)
	}

	// Start the listener
	require.NoError(t, sock.Setup())
	sock.ListenConnection(onConnection, nil)
	defer sock.Close()

	addr := sock.Address()

	// Setup the client side TLS
	tlsCfg, err := pki.TLSClientConfig().TLSConfig()
	require.NoError(t, err)

	// Define a single client write sequence
	data := []byte("test value=42i")
	write := func() error {
		conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
		if err != nil {
			return err
		}
		defer conn.Close()

		_, err = conn.Write(data)
		return err
	}

	// Define a test with the given number of connections
	maxConcurrency := runtime.GOMAXPROCS(0)
	testCycle := func(connections int) (uint64, error) {
		var mu sync.Mutex
		var errs []error
		var wg sync.WaitGroup
		for count := 1; count < connections; count++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				if err := write(); err != nil {
					mu.Lock()
					errs = append(errs, err)
					mu.Unlock()
				}
			}()
			if count%maxConcurrency == 0 {
				wg.Wait()
				mu.Lock()
				if len(errs) > 0 {
					mu.Unlock()
					return 0, errors.Join(errs...)
				}
				mu.Unlock()
			}
		}
		// Wait for the remaining writers before measuring
		wg.Wait()
		//nolint:revive // We need to actively run the garbage collector to get reliable measurements
		runtime.GC()

		var stats runtime.MemStats
		runtime.ReadMemStats(&stats)
		return stats.HeapObjects, nil
	}

	// Measure the memory usage after a short warmup and again after many
	// more runs. The final number of heap objects should stay within a
	// safe margin of the initial measurement.

	// Warmup, do a low number of runs to initialize all data structures
	// taking them out of the equation.
	initial, err := testCycle(100)
	require.NoError(t, err)

	// Do some more runs and make sure the memory growth is bound
	final, err := testCycle(2000)
	require.NoError(t, err)

	require.Less(t, final, 2*initial)
}
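
The leak check boils down to a small reusable idiom: force a collection so transient garbage does not skew the reading, snapshot the live heap-object count, then assert bounded growth after doing more work. Distilled into a standalone sketch (the helper name is mine, not Telegraf's):

```go
package main

import (
	"fmt"
	"runtime"
)

// heapObjects forces a GC so transient garbage does not inflate the
// reading, then returns the number of live objects on the heap.
func heapObjects() uint64 {
	runtime.GC()
	var stats runtime.MemStats
	runtime.ReadMemStats(&stats)
	return stats.HeapObjects
}

func main() {
	initial := heapObjects() // warmup baseline, lazy init included

	work := make([][]byte, 0, 2000)
	for i := 0; i < 2000; i++ {
		work = append(work, make([]byte, 128))
	}
	work = nil // drop references so the next GC can reclaim everything

	final := heapObjects()
	// A leak shows up as final growing with the amount of work done;
	// the bound mirrors the test's require.Less(t, final, 2*initial).
	fmt.Println(final < 2*initial)
}
```
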

func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
	// Determine the protocol in a crude fashion
	parts := strings.SplitN(endpoint, "://", 2)
