Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/dbterror"
"go.uber.org/atomic"
)

var (
Expand Down Expand Up @@ -76,7 +77,7 @@ type PacketIO struct {
conn net.Conn
buf *bufio.ReadWriter
sequence uint8
proxyInited bool
proxyInited *atomic.Bool
proxy *Proxy
}

Expand All @@ -91,14 +92,21 @@ func NewPacketIO(conn net.Conn) *PacketIO {
buf.Reader,
},
sequence: 0,
// TODO: enable proxy probe for clients only
// disable it by default now
proxyInited: true,
// TODO: disable it by default now
proxyInited: atomic.NewBool(true),
buf: buf,
}
return p
}

// Proxy returned parsed proxy header from clients if any.
func (p *PacketIO) Proxy() *Proxy {
if p.proxyInited.Load() {
return p.proxy
}
return nil
}

func (p *PacketIO) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
Expand All @@ -120,7 +128,7 @@ func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {

// probe proxy V2
refill := false
if !p.proxyInited {
if !p.proxyInited.Load() {
if bytes.Compare(header[:], proxyV2Magic[:4]) == 0 {
proxyHeader, err := p.parseProxyV2()
if err != nil {
Expand All @@ -131,7 +139,7 @@ func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) {
refill = true
}
}
p.proxyInited = true
p.proxyInited.Store(true)
}

// refill mysql headers
Expand Down
4 changes: 3 additions & 1 deletion pkg/proxy/net/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,11 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) {
return m, nil
}

func (p *PacketIO) writeProxyV2(m *Proxy) error {
// WriteProxyV2 should only be called at the beginning of connection, before any write operations.
func (p *PacketIO) WriteProxyV2(m *Proxy) error {
if _, err := io.Copy(p.buf, bytes.NewReader(m.ToBytes())); err != nil {
return errors.WithStack(errors.Wrap(ErrWriteConn, err))
}
// according to the spec, we better flush to avoid server hanging
return p.Flush()
}
2 changes: 1 addition & 1 deletion pkg/proxy/net/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestProxy(t *testing.T) {

testPipeConn(t,
func(t *testing.T, cli *PacketIO) {
require.NoError(t, cli.writeProxyV2(&Proxy{
require.NoError(t, cli.WriteProxyV2(&Proxy{
Version: ProxyVersion2,
Command: ProxyCommandLocal,
SrcAddress: tcpaddr,
Expand Down