Skip to content

Commit

Permalink
security: report warning when TLS version is below 1.2 (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Aug 30, 2023
1 parent efd29b1 commit 7f0eed6
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 20 deletions.
17 changes: 0 additions & 17 deletions lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ package config

import (
"bytes"
"crypto/tls"
"os"
"path/filepath"
"strings"
"time"

"github.com/BurntSushi/toml"
Expand Down Expand Up @@ -111,21 +109,6 @@ func (c TLSConfig) HasCert() bool {
return !(c.Cert == "" && c.Key == "")
}

func (c TLSConfig) MinTLSVer() uint16 {
switch {
case strings.HasSuffix(c.MinTLSVersion, "1.0"):
return tls.VersionTLS10
case strings.HasSuffix(c.MinTLSVersion, "1.1"):
return tls.VersionTLS11
case strings.HasSuffix(c.MinTLSVersion, "1.2"):
return tls.VersionTLS12
case strings.HasSuffix(c.MinTLSVersion, "1.3"):
return tls.VersionTLS13
default:
return tls.VersionTLS12
}
}

func (c TLSConfig) HasCA() bool {
return c.CA != ""
}
Expand Down
6 changes: 3 additions & 3 deletions lib/util/security/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
}

tcfg := &tls.Config{
MinVersion: cfg.MinTLSVer(),
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
GetCertificate: ci.getCert,
GetClientCertificate: ci.getClientCert,
VerifyPeerCertificate: ci.verifyPeerCertificate,
Expand Down Expand Up @@ -221,15 +221,15 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
// still enable TLS without verify server certs
return &tls.Config{
InsecureSkipVerify: true,
MinVersion: cfg.MinTLSVer(),
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
}, nil
}
lg.Info("no CA to verify server connections, disable TLS")
return nil, nil
}

tcfg := &tls.Config{
MinVersion: cfg.MinTLSVer(),
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
GetCertificate: ci.getCert,
GetClientCertificate: ci.getClientCert,
InsecureSkipVerify: true,
Expand Down
23 changes: 23 additions & 0 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net"
"os"
"path/filepath"
"strings"
"time"

"github.com/pingcap/TiProxy/lib/config"
Expand Down Expand Up @@ -236,3 +237,25 @@ func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSConfig) (*tls.Config

return tcfg, nil
}

// GetMinTLSVer parses the min tls version from config and reports warning if necessary.
func GetMinTLSVer(tlsVerStr string, logger *zap.Logger) uint16 {
var minTLSVersion uint16 = tls.VersionTLS12
switch {
case strings.HasSuffix(tlsVerStr, "1.0"):
minTLSVersion = tls.VersionTLS10
case strings.HasSuffix(tlsVerStr, "1.1"):
minTLSVersion = tls.VersionTLS11
case strings.HasSuffix(tlsVerStr, "1.2"):
minTLSVersion = tls.VersionTLS12
case strings.HasSuffix(tlsVerStr, "1.3"):
minTLSVersion = tls.VersionTLS13
case len(tlsVerStr) == 0:
default:
logger.Warn("Invalid TLS version, using default instead", zap.String("tls-version", tlsVerStr))
}
if minTLSVersion < tls.VersionTLS12 {
logger.Warn("Minimum TLS version allows pre-TLSv1.2 protocols, this is not recommended")
}
return minTLSVersion
}
48 changes: 48 additions & 0 deletions lib/util/security/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package security

import (
"crypto/tls"
"testing"

"github.com/pingcap/TiProxy/lib/util/logger"
"github.com/stretchr/testify/require"
)

Expand All @@ -15,3 +17,49 @@ func BenchmarkCreateTLS(b *testing.B) {
require.Nil(b, err)
}
}

func TestGetMinTLSVer(t *testing.T) {
tests := []struct {
verStr string
verInt uint16
warn string
}{
{
verStr: "v1.0",
verInt: tls.VersionTLS10,
warn: "not recommended",
},
{
verStr: "v1.1",
verInt: tls.VersionTLS11,
warn: "not recommended",
},
{
verStr: "v1.2",
verInt: tls.VersionTLS12,
},
{
verStr: "v1.3",
verInt: tls.VersionTLS13,
},
{
verStr: "unknown",
verInt: tls.VersionTLS12,
warn: "Invalid TLS version",
},
{
verInt: tls.VersionTLS12,
},
}

for _, test := range tests {
lg, text := logger.CreateLoggerForTest(t)
res := GetMinTLSVer(test.verStr, lg)
require.Equal(t, test.verInt, res)
if len(test.warn) > 0 {
require.Contains(t, text.String(), test.warn)
} else {
require.Empty(t, text.String())
}
}
}

0 comments on commit 7f0eed6

Please sign in to comment.