Skip to content

Commit

Permalink
fix: adding timeout to the typha server TLS handshake (#7909)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigorfk authored Aug 23, 2023
1 parent 9a4557c commit ad8bd00
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 1 deletion.
24 changes: 24 additions & 0 deletions typha/fv-tests/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/gob"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
Expand Down Expand Up @@ -1611,6 +1612,19 @@ var _ = Describe("with server requiring TLS", func() {
}
}

testTcpHalfOpen := func() {
deadline := time.Now().Add(9 * time.Second)
serverAddr := fmt.Sprintf("127.0.0.1:%d", server.Port())
tcpConn, err := net.Dial("tcp", serverAddr)
Expect(err).NotTo(HaveOccurred())
err = tcpConn.SetDeadline(time.Now().Add(30 * time.Second))
Expect(err).NotTo(HaveOccurred())
received := make([]byte, 1024)
_, err = tcpConn.Read(received)
Expect(err).Should(Equal(io.EOF))
Expect(time.Now().Unix()).Should(BeNumerically(">=", deadline.Unix()))
}

Describe("and CN or URI SAN", func() {
BeforeEach(func() {
requiredClientCN = clientCN
Expand Down Expand Up @@ -1686,4 +1700,14 @@ var _ = Describe("with server requiring TLS", func() {

It("TLS connection with good CN and URI should fail", testTLSGoodCNURI(false))
})

Describe("handshake", func() {
BeforeEach(func() {
requiredClientCN = clientCN
requiredClientURISAN = clientURISAN
serverCertName = "server"
})

It("should timeout after 10 seconds for TCP half open connections", testTcpHalfOpen)
})
})
1 change: 1 addition & 0 deletions typha/pkg/config/config_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ type Config struct {
ServerMinBatchingAgeThresholdSecs time.Duration `config:"seconds;0.01"`
ServerPingIntervalSecs time.Duration `config:"seconds;10"`
ServerPongTimeoutSecs time.Duration `config:"seconds;60"`
ServerHandshakeTimeoutSecs time.Duration `config:"seconds;10"`
ServerPort int `config:"port;0"`

// Server-side TLS config for Typha's communication with Felix. If any of these are
Expand Down
1 change: 1 addition & 0 deletions typha/pkg/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ func (t *TyphaDaemon) CreateServer() {
NewClientFallBehindGracePeriod: t.ConfigParams.ServerNewClientFallBehindGracePeriod,
PingInterval: t.ConfigParams.ServerPingIntervalSecs,
PongTimeout: t.ConfigParams.ServerPongTimeoutSecs,
HandshakeTimeout: t.ConfigParams.ServerHandshakeTimeoutSecs,
DropInterval: t.ConfigParams.ConnectionDropIntervalSecs,
ShutdownTimeout: t.ConfigParams.ShutdownTimeoutSecs,
ShutdownMaxDropInterval: t.ConfigParams.ShutdownConnectionDropIntervalMaxSecs,
Expand Down
15 changes: 14 additions & 1 deletion typha/pkg/syncserver/sync_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ const (
defaultBatchingAgeThreshold = 100 * time.Millisecond
defaultPingInterval = 10 * time.Second
defaultWriteTimeout = 120 * time.Second
defaultHandshakeTimeout = 10 * time.Second
defaultDropInterval = 1 * time.Second
defaultShutdownTimeout = 300 * time.Second
defaultMaxConns = math.MaxInt32
Expand Down Expand Up @@ -140,6 +141,7 @@ type Config struct {
MinBatchingAgeThreshold time.Duration
PingInterval time.Duration
PongTimeout time.Duration
HandshakeTimeout time.Duration
WriteTimeout time.Duration
DropInterval time.Duration
ShutdownTimeout time.Duration
Expand Down Expand Up @@ -217,6 +219,13 @@ func (c *Config) ApplyDefaults() {
}).Info("PongTimeout < PingInterval * 2; Defaulting PongTimeout.")
c.PongTimeout = defaultTimeout
}
if c.HandshakeTimeout <= 0 {
log.WithFields(log.Fields{
"value": c.HandshakeTimeout,
"default": defaultHandshakeTimeout,
}).Info("Defaulting HandshakeTimeout.")
c.HandshakeTimeout = defaultHandshakeTimeout
}
if c.WriteTimeout <= 0 {
log.WithField("default", defaultWriteTimeout).Info("Defaulting write timeout.")
c.WriteTimeout = defaultWriteTimeout
Expand Down Expand Up @@ -424,7 +433,11 @@ func (s *Server) serve(cxt context.Context) {
// Doing TLS, we must do the handshake...
tlsConn := conn.(*tls.Conn)
logCxt.Debug("TLS connection")
err = tlsConn.Handshake()
err = func() error {
handshakeCxt, handshakeCancel := context.WithTimeout(cxt, s.config.HandshakeTimeout)
defer handshakeCancel()
return tlsConn.HandshakeContext(handshakeCxt)
}()
if err != nil {
logCxt.WithError(err).Error("TLS handshake error")
err = conn.Close()
Expand Down
1 change: 1 addition & 0 deletions typha/pkg/syncserver/syncserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ var _ = Describe("With zero config", func() {
MinBatchingAgeThreshold: 100 * time.Millisecond,
PingInterval: 10 * time.Second,
PongTimeout: 60 * time.Second,
HandshakeTimeout: 10 * time.Second,
DropInterval: time.Second,
ShutdownTimeout: 300 * time.Second,
ShutdownMaxDropInterval: time.Second,
Expand Down

0 comments on commit ad8bd00

Please sign in to comment.