Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: fix TLS buffering #19

Merged
merged 11 commits into from
Jul 28, 2022
35 changes: 21 additions & 14 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,39 @@ const (
defaultReaderSize = 16 * 1024
)

type rdbufConn struct {
net.Conn
*bufio.Reader
}

func (f *rdbufConn) Read(b []byte) (int, error) {
return f.Reader.Read(b)
}

// PacketIO is a helper to read and write sql and proxy protocol.
type PacketIO struct {
conn net.Conn
tlsConn net.Conn
buf *bufio.ReadWriter
sequence uint8
proxyInited bool
proxy *Proxy
}

func NewPacketIO(conn net.Conn) *PacketIO {
buf := bufio.NewReadWriter(
xhebox marked this conversation as resolved.
Show resolved Hide resolved
bufio.NewReaderSize(conn, defaultReaderSize),
bufio.NewWriterSize(conn, defaultWriterSize),
)
p := &PacketIO{
conn: conn,
conn: &rdbufConn{
xhebox marked this conversation as resolved.
Show resolved Hide resolved
conn,
buf.Reader,
},
sequence: 0,
// TODO: enable proxy probe for clients only
// disable it by default now
proxyInited: true,
buf: bufio.NewReadWriter(
bufio.NewReaderSize(conn, defaultReaderSize),
bufio.NewWriterSize(conn, defaultWriterSize),
),
buf: buf,
}
return p
}
Expand All @@ -102,7 +114,7 @@ func (p *PacketIO) ResetSequence() {
func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {
var header [4]byte

if _, err := io.ReadFull(p.buf, header[:]); err != nil {
if _, err := io.ReadFull(p.conn, header[:]); err != nil {
return nil, false, errors.WithStack(errors.Wrap(ErrReadConn, err))
}

Expand All @@ -124,7 +136,7 @@ func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {

// refill mysql headers
if refill {
if _, err := io.ReadFull(p.buf, header[:]); err != nil {
if _, err := io.ReadFull(p.conn, header[:]); err != nil {
return nil, false, errors.WithStack(errors.Wrap(ErrReadConn, err))
}
}
Expand All @@ -137,7 +149,7 @@ func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

data := make([]byte, length)
if _, err := io.ReadFull(p.buf, data); err != nil {
if _, err := io.ReadFull(p.conn, data); err != nil {
return nil, false, errors.WithStack(errors.Wrap(ErrReadConn, err))
}
return data, length == mysql.MaxPayloadLen, nil
Expand Down Expand Up @@ -224,11 +236,6 @@ func (p *PacketIO) Close() error {
errs = append(errs, err)
}
*/
if p.tlsConn != nil {
if err := p.tlsConn.Close(); err != nil {
errs = append(errs, err)
}
}
if p.conn != nil {
if err := p.conn.Close(); err != nil {
errs = append(errs, err)
Expand Down
158 changes: 158 additions & 0 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@
package net

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"encoding/pem"
"math/big"
"net"
"testing"
"time"

"github.com/pingcap/TiProxy/pkg/util/waitgroup"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -96,3 +105,152 @@ func TestPacketIO(t *testing.T) {
},
)
}

// certsetup is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
func certsetup() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
// set up our CA certificate
ca := &x509.Certificate{
xhebox marked this conversation as resolved.
Show resolved Hide resolved
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}

// create the CA
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}

// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})

caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})

// set up our server certificate
cert := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}

certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}

certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}

certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})

certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})

serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
return nil, nil, err
}

serverTLSConf = &tls.Config{
Certificates: []tls.Certificate{serverCert},
}

certpool := x509.NewCertPool()
certpool.AppendCertsFromPEM(caPEM.Bytes())
clientTLSConf = &tls.Config{
RootCAs: certpool,
}

return
}

func TestTLS(t *testing.T) {
listener, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
defer func() {
require.NoError(t, listener.Close())
}()

var wg waitgroup.WaitGroup
stls, ctls, err := certsetup()
require.NoError(t, err)
for i := 0; i < 500; i++ {
wg.Run(func() {
srv, err := listener.Accept()
require.NoError(t, err)

srvIO := NewPacketIO(srv)
err = srvIO.WritePacket([]byte("hello"), true)
require.NoError(t, err)

_, err = srvIO.UpgradeToServerTLS(stls)
require.NoError(t, err)

data, err := srvIO.ReadPacket()
require.NoError(t, err)
require.Equal(t, []byte("world"), data)
xhebox marked this conversation as resolved.
Show resolved Hide resolved
})
wg.Run(func() {
cli, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)

cliIO := NewPacketIO(cli)
data, err := cliIO.ReadPacket()
require.NoError(t, err)
require.Equal(t, []byte("hello"), data)

require.NoError(t, cliIO.UpgradeToClientTLS(ctls))

err = cliIO.WritePacket([]byte("world"), true)
require.NoError(t, err)
})
wg.Wait()
}
}
2 changes: 1 addition & 1 deletion pkg/proxy/net/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestProxy(t *testing.T) {
func(t *testing.T, srv *PacketIO) {
// skip 4 bytes of magic
var hdr [4]byte
_, err := io.ReadFull(srv.buf, hdr[:])
_, err := io.ReadFull(srv.conn, hdr[:])
require.NoError(t, err)

// try to parse V2
Expand Down
16 changes: 8 additions & 8 deletions pkg/proxy/net/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@ import (
)

func (p *PacketIO) UpgradeToServerTLS(tlsConfig *tls.Config) (tls.ConnectionState, error) {
tlsConfig = tlsConfig.Clone()
tlsConn := tls.Server(p.conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return tlsConn.ConnectionState(), errors.WithStack(errors.Wrap(ErrHandshakeTLS, err))
}
p.buf.Reader.Reset(tlsConn)
p.buf.Writer.Reset(tlsConn)
p.conn = tlsConn
xhebox marked this conversation as resolved.
Show resolved Hide resolved
p.buf.Writer.Reset(p.conn)
return tlsConn.ConnectionState(), nil
}

func (p *PacketIO) UpgradeToClientTLS(tlsConfig *tls.Config) error {
tlsConfig = tlsConfig.Clone()
host, _, err := net.SplitHostPort(p.conn.RemoteAddr().String())
if err != nil {
return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err))
if err == nil {
tlsConfig.ServerName = host
}
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = host
tlsConn := tls.Client(p.conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err))
}
p.buf.Reader.Reset(tlsConn)
p.buf.Writer.Reset(tlsConn)
p.conn = tlsConn
p.buf.Writer.Reset(p.conn)
return nil
}