Skip to content

Commit

Permalink
proxy: extract tcp logic out further (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored Aug 3, 2022
1 parent d86b629 commit bc0f00e
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 181 deletions.
83 changes: 29 additions & 54 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package backend

import (
"crypto/tls"
"net"
"strings"
"testing"

Expand Down Expand Up @@ -44,22 +42,17 @@ func TestTLSConnection(t *testing.T) {
},
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
runTest(t, func(backendListener, proxyListener net.Listener, clientTLSConfig, backendTLSConfig *tls.Config) {
for _, cfgs := range cfgOverriders {
cfg := newTestConfig(cfgs...)
cfg.setTLSConfig(clientTLSConfig, backendTLSConfig)
ts := newTestSuite(t, cfg)
clientErr, proxyErr, backendErr := ts.authenticateFirstTime(backendListener, proxyListener)
if cfg.backendConfig.capability&mysql.ClientSSL == 0 {
require.ErrorContains(t, proxyErr, "must enable TLS")
} else {
require.NoError(t, clientErr)
require.NoError(t, proxyErr)
require.NoError(t, backendErr)
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite, _, _, perr error) {
if ts.mb.backendConfig.capability&mysql.ClientSSL == 0 {
require.ErrorContains(t, perr, "must enable TLS")
}
}
})
})
clean()
}
}

func TestAuthPlugin(t *testing.T) {
Expand Down Expand Up @@ -114,18 +107,13 @@ func TestAuthPlugin(t *testing.T) {
},
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
runTest(t, func(backendListener, proxyListener net.Listener, clientTLSConfig, backendTLSConfig *tls.Config) {
for _, cfgs := range cfgOverriders {
cfg := newTestConfig(cfgs...)
cfg.setTLSConfig(clientTLSConfig, backendTLSConfig)
ts := newTestSuite(t, cfg)
clientErr, proxyErr, backendErr := ts.authenticateFirstTime(backendListener, proxyListener)
require.NoError(t, clientErr)
require.NoError(t, proxyErr)
require.NoError(t, backendErr)
}
})
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, nil)
clean()
}
}

func TestCapability(t *testing.T) {
Expand Down Expand Up @@ -165,18 +153,13 @@ func TestCapability(t *testing.T) {
},
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
runTest(t, func(backendListener, proxyListener net.Listener, clientTLSConfig, backendTLSConfig *tls.Config) {
for _, cfgs := range cfgOverriders {
cfg := newTestConfig(cfgs...)
cfg.setTLSConfig(clientTLSConfig, backendTLSConfig)
ts := newTestSuite(t, cfg)
clientErr, proxyErr, backendErr := ts.authenticateFirstTime(backendListener, proxyListener)
require.NoError(t, clientErr)
require.NoError(t, proxyErr)
require.NoError(t, backendErr)
}
})
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, nil)
clean()
}
}

func TestSecondHandshake(t *testing.T) {
Expand All @@ -193,20 +176,12 @@ func TestSecondHandshake(t *testing.T) {
},
}

runTest(t, func(backendListener, proxyListener net.Listener, clientTLSConfig, backendTLSConfig *tls.Config) {
for _, hook := range hooks {
cfg := newTestConfig()
cfg.setTLSConfig(clientTLSConfig, backendTLSConfig)
ts := newTestSuite(t, cfg)
clientErr, proxyErr, backendErr := ts.authenticateFirstTime(backendListener, proxyListener)
require.NoError(t, clientErr)
require.NoError(t, proxyErr)
require.NoError(t, backendErr)
// Call the hook after first handshake.
hook(ts)
proxyErr, backendErr = ts.authenticateSecondTime(backendListener, proxyListener)
require.NoError(t, proxyErr)
require.NoError(t, backendErr)
}
})
tc := newTCPConnSuite(t)
for _, hook := range hooks {
ts, clean := newTestSuite(t, tc)
ts.authenticateFirstTime(t, nil)
hook(ts)
ts.authenticateSecondTime(t, nil)
clean()
}
}
118 changes: 118 additions & 0 deletions pkg/proxy/backend/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
"crypto/tls"
"net"
"testing"

pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/TiProxy/pkg/util/security"
"github.com/pingcap/tidb/util"
"github.com/stretchr/testify/require"
)

type tcpConnSuite struct {
backendListener net.Listener
proxyListener net.Listener
backendTLSConfig *tls.Config
clientTLSConfig *tls.Config
backendIO *pnet.PacketIO
proxyBIO *pnet.PacketIO
proxyCIO *pnet.PacketIO
clientIO *pnet.PacketIO
}

func newTCPConnSuite(t *testing.T) *tcpConnSuite {
var err error

r := &tcpConnSuite{}

r.backendListener, err = net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
r.proxyListener, err = net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
r.backendTLSConfig, r.clientTLSConfig, err = security.CreateTLSConfigForTest()
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.backendListener.Close())
require.NoError(t, r.proxyListener.Close())
})

return r
}

func (tc *tcpConnSuite) newConn(t *testing.T) func() {
var wg util.WaitGroupWrapper
wg.Run(func() {
conn, err := tc.backendListener.Accept()
require.NoError(t, err)
tc.backendIO = pnet.NewPacketIO(conn)
})
wg.Run(func() {
backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String())
require.NoError(t, err)
tc.proxyBIO = pnet.NewPacketIO(backendConn)

clientConn, err := tc.proxyListener.Accept()
require.NoError(t, err)
tc.proxyCIO = pnet.NewPacketIO(clientConn)
})
wg.Run(func() {
conn, err := net.Dial("tcp", tc.proxyListener.Addr().String())
require.NoError(t, err)
tc.clientIO = pnet.NewPacketIO(conn)
})
wg.Wait()
return func() {
// may be closed twice
_ = tc.clientIO.Close()
_ = tc.proxyCIO.Close()
_ = tc.proxyBIO.Close()
_ = tc.backendIO.Close()
}
}

func (tc *tcpConnSuite) run(t *testing.T, clientRunner, backendRunner func(*pnet.PacketIO) error, proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) (cerr, berr, perr error) {
var wg util.WaitGroupWrapper
if clientRunner != nil {
wg.Run(func() {
cerr = clientRunner(tc.clientIO)
if cerr != nil {
require.NoError(t, tc.clientIO.Close())
}
})
}
if backendRunner != nil {
wg.Run(func() {
berr = backendRunner(tc.backendIO)
if berr != nil {
require.NoError(t, tc.backendIO.Close())
}
})
}
if proxyRunner != nil {
wg.Run(func() {
perr = proxyRunner(tc.proxyCIO, tc.proxyBIO)
if perr != nil {
require.NoError(t, tc.proxyCIO.Close())
require.NoError(t, tc.proxyBIO.Close())
}
})
}
wg.Wait()
return
}
Loading

0 comments on commit bc0f00e

Please sign in to comment.