Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(common.socket): Switch to context to simplify closing #15589

Merged
merged 4 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/common/socket/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type listener interface {
}

type Config struct {
MaxConnections int `toml:"max_connections"`
MaxConnections uint64 `toml:"max_connections"`
ReadBufferSize config.Size `toml:"read_buffer_size"`
ReadTimeout config.Duration `toml:"read_timeout"`
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
Expand Down
223 changes: 222 additions & 1 deletion plugins/common/socket/socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package socket

import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -481,7 +483,7 @@ func TestClosingConnections(t *testing.T) {
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
listener.Lock()
conns := len(listener.connections)
conns := listener.connections
listener.Unlock()
require.NotZero(t, conns)

Expand All @@ -496,6 +498,130 @@ func TestClosingConnections(t *testing.T) {
require.Empty(t, logger.Errors())
require.Empty(t, logger.Warnings())
}
func TestMaxConnections(t *testing.T) {
if runtime.GOOS == "darwin" {
t.Skip("Skipping on darwin due to missing socket options")
}

// Setup the configuration
period := config.Duration(10 * time.Millisecond)
cfg := &Config{
MaxConnections: 5,
KeepAlivePeriod: &period,
}

// Create the socket
serviceAddress := "tcp://127.0.0.1:0"
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)

// Create callback
var errs []error
var mu sync.Mutex
onData := func(_ net.Addr, _ []byte) {}
onError := func(err error) {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}

// Start the listener
require.NoError(t, sock.Setup())
sock.Listen(onData, onError)
defer sock.Close()

addr := sock.Address()

// Create maximum number of connections and write some data. All of this
// should succeed...
clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
for i := 0; i < int(cfg.MaxConnections); i++ {
c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, c.SetWriteBuffer(0))
require.NoError(t, c.SetNoDelay(true))
clients = append(clients, c)

_, err = c.Write([]byte("test value=42i\n"))
require.NoError(t, err)
}

func() {
mu.Lock()
defer mu.Unlock()
require.Empty(t, errs)
}()

// Create another client. This should fail because we already reached the
// connection limit and the connection should be closed...
client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, client.SetWriteBuffer(0))
require.NoError(t, client.SetNoDelay(true))

require.Eventually(t, func() bool {
mu.Lock()
defer mu.Unlock()
return len(errs) > 0
}, 3*time.Second, 100*time.Millisecond)
func() {
mu.Lock()
defer mu.Unlock()
require.Len(t, errs, 1)
require.ErrorContains(t, errs[0], "too many connections")
errs = make([]error, 0)
}()

require.Eventually(t, func() bool {
_, err := client.Write([]byte("fail\n"))
return err != nil
}, 3*time.Second, 100*time.Millisecond)
_, err = client.Write([]byte("test\n"))
require.Error(t, err)

// Check other connections are still good
for _, c := range clients {
_, err := c.Write([]byte("test\n"))
require.NoError(t, err)
}
func() {
mu.Lock()
defer mu.Unlock()
require.Empty(t, errs)
}()

// Close the first client and check if we can connect now
require.NoError(t, clients[0].Close())
client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, client.SetWriteBuffer(0))
require.NoError(t, client.SetNoDelay(true))
_, err = client.Write([]byte("success\n"))
require.NoError(t, err)

// Close all connections
require.NoError(t, client.Close())
for _, c := range clients[1:] {
require.NoError(t, c.Close())
}

// Close the clients and check the connection counter
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
require.Eventually(t, func() bool {
listener.Lock()
conns := listener.connections
listener.Unlock()
return conns == 0
}, 3*time.Second, 100*time.Millisecond)

// Close the socket and check again...
sock.Close()
listener.Lock()
conns := listener.connections
listener.Unlock()
require.Zero(t, conns)
}

func TestNoSplitter(t *testing.T) {
messages := [][]byte{
Expand Down Expand Up @@ -605,6 +731,101 @@ func TestNoSplitter(t *testing.T) {
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
}

func TestTLSMemLeak(t *testing.T) {
// For issue https://github.com/influxdata/telegraf/issues/15509

// Prepare the address and socket if needed
serviceAddress := "tcp://127.0.0.1:0"

// Setup a TLS socket to trigger the issue
cfg := &Config{
ServerConfig: *pki.TLSServerConfig(),
}

// Create the socket
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)

// Create callbacks
onConnection := func(_ net.Addr, reader io.ReadCloser) {
//nolint:errcheck // We are not interested in the data so ignore all errors
io.Copy(io.Discard, reader)
}

// Start the listener
require.NoError(t, sock.Setup())
sock.ListenConnection(onConnection, nil)
defer sock.Close()

addr := sock.Address()

// Setup the client side TLS
tlsCfg, err := pki.TLSClientConfig().TLSConfig()
require.NoError(t, err)

// Define a single client write sequence
data := []byte("test value=42i")
write := func() error {
conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
if err != nil {
return err
}
defer conn.Close()

_, err = conn.Write(data)
return err
}

// Define a test with the given number of connections
maxConcurrency := runtime.GOMAXPROCS(0)
testCycle := func(connections int) (uint64, error) {
var mu sync.Mutex
var errs []error
var wg sync.WaitGroup
for count := 1; count < connections; count++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := write(); err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}()
if count%maxConcurrency == 0 {
wg.Wait()
mu.Lock()
if len(errs) > 0 {
mu.Unlock()
return 0, errors.Join(errs...)
}
mu.Unlock()
}
}
//nolint:revive // We need to actively run the garbage collector to get reliable measurements
runtime.GC()

var stats runtime.MemStats
runtime.ReadMemStats(&stats)
return stats.HeapObjects, nil
}

// Measure the memory usage after a short warmup and after some time.
// The final number of heap objects should not exceed the number of
// runs by a save margin

// Warmup, do a low number of runs to initialize all data structures
// taking them out of the equation.
initial, err := testCycle(100)
require.NoError(t, err)

// Do some more runs and make sure the memory growth is bound
final, err := testCycle(2000)
require.NoError(t, err)

require.Less(t, final, 2*initial)
}

func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
// Determine the protocol in a crude fashion
parts := strings.SplitN(endpoint, "://", 2)
Expand Down
Loading
Loading