Skip to content

Commit

Permalink
net: fix proxy protocol (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored May 5, 2023
1 parent 396efd1 commit d0e6292
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 74 deletions.
12 changes: 6 additions & 6 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,15 @@ func NewPacketIO(conn net.Conn, opts ...PacketIOption) *PacketIO {
sequence: 0,
buf: buf,
}
// TODO: disable it by default now
p.proxyInited.Store(true)
p.ApplyOpts(opts...)
return p
}

func (p *PacketIO) ApplyOpts(opts ...PacketIOption) {
for _, opt := range opts {
opt(p)
}
return p
}

func (p *PacketIO) wrapErr(err error) error {
Expand All @@ -117,10 +120,7 @@ func (p *PacketIO) wrapErr(err error) error {

// Proxy returned parsed proxy header from clients if any.
func (p *PacketIO) Proxy() *proxyprotocol.Proxy {
if p.proxyInited.Load() {
return p.proxy
}
return nil
return p.proxy
}

func (p *PacketIO) LocalAddr() net.Addr {
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/net/packetio_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
type PacketIOption = func(*PacketIO)

func WithProxy(pi *PacketIO) {
pi.proxyInited.Store(true)
pi.proxyInited.Store(false)
}

func WithWrapError(err error) func(pi *PacketIO) {
Expand Down
65 changes: 18 additions & 47 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,62 +22,33 @@ import (

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/security"
"github.com/pingcap/TiProxy/lib/util/waitgroup"
"github.com/pingcap/TiProxy/pkg/testkit"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)

func testPipeConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) {
var wg waitgroup.WaitGroup
client, server := net.Pipe()
cli, srv := NewPacketIO(client), NewPacketIO(server)
if ddl, ok := t.Deadline(); ok {
require.NoError(t, client.SetDeadline(ddl))
require.NoError(t, server.SetDeadline(ddl))
}
for i := 0; i < loop; i++ {
wg.Run(func() {
testkit.TestPipeConn(t,
func(t *testing.T, c net.Conn) {
a(t, NewPacketIO(c))
},
func(t *testing.T, c net.Conn) {
b(t, NewPacketIO(c))
}, loop)
}

func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) {
testkit.TestTCPConn(t,
func(t *testing.T, c net.Conn) {
cli := NewPacketIO(c)
a(t, cli)
require.NoError(t, cli.Close())
})
wg.Run(func() {
},
func(t *testing.T, c net.Conn) {
srv := NewPacketIO(c)
b(t, srv)
require.NoError(t, srv.Close())
})
wg.Wait()
}
}

func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) {
listener, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
defer func() {
require.NoError(t, listener.Close())
}()
var wg waitgroup.WaitGroup
for i := 0; i < loop; i++ {
wg.Run(func() {
cli, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
if ddl, ok := t.Deadline(); ok {
require.NoError(t, cli.SetDeadline(ddl))
}
cliIO := NewPacketIO(cli)
a(t, cliIO)
require.NoError(t, cliIO.Close())
})
wg.Run(func() {
srv, err := listener.Accept()
require.NoError(t, err)
if ddl, ok := t.Deadline(); ok {
require.NoError(t, srv.SetDeadline(ddl))
}
srvIO := NewPacketIO(srv)
b(t, srvIO)
require.NoError(t, srvIO.Close())
})
wg.Wait()
}
}, loop)
}

func TestPacketIO(t *testing.T) {
Expand Down
65 changes: 65 additions & 0 deletions pkg/proxy/net/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package net

import (
"bytes"
"io"
"net"
"testing"

"github.com/pingcap/TiProxy/pkg/proxy/proxyprotocol"
"github.com/stretchr/testify/require"
)

func TestProxyParse(t *testing.T) {
tcpaddr, err := net.ResolveTCPAddr("tcp", "192.168.1.1:34")
require.NoError(t, err)

testPipeConn(t,
func(t *testing.T, cli *PacketIO) {
p := &proxyprotocol.Proxy{
Version: proxyprotocol.ProxyVersion2,
Command: proxyprotocol.ProxyCommandLocal,
SrcAddress: tcpaddr,
DstAddress: tcpaddr,
TLV: []proxyprotocol.ProxyTlv{
{
Typ: proxyprotocol.ProxyTlvALPN,
Content: nil,
},
{
Typ: proxyprotocol.ProxyTlvUniqueID,
Content: []byte("test"),
},
},
}
b, err := p.ToBytes()
require.NoError(t, err)
_, err = io.Copy(cli.conn, bytes.NewReader(b))
require.NoError(t, err)
err = cli.WritePacket([]byte("hello"), true)
require.NoError(t, err)
},
func(t *testing.T, srv *PacketIO) {
srv.ApplyOpts(WithProxy)
b, err := srv.ReadPacket()
require.NoError(t, err)
require.Equal(t, "hello", string(b))
require.Equal(t, tcpaddr.String(), srv.RemoteAddr().String())
},
1,
)
}
4 changes: 2 additions & 2 deletions pkg/proxy/proxyprotocol/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ const (
)

type ProxyTlv struct {
content []byte
typ ProxyTlvType
Content []byte
Typ ProxyTlvType
}

type Proxy struct {
Expand Down
8 changes: 4 additions & 4 deletions pkg/proxy/proxyprotocol/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ func TestProxyListener(t *testing.T) {
DstAddress: tcpaddr,
TLV: []ProxyTlv{
{
typ: ProxyTlvALPN,
content: nil,
Typ: ProxyTlvALPN,
Content: nil,
},
{
typ: ProxyTlvUniqueID,
content: []byte("test"),
Typ: ProxyTlvUniqueID,
Content: []byte("test"),
},
},
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/proxy/proxyprotocol/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ func (p *Proxy) ToBytes() ([]byte, error) {
buf[magicLen+1] = byte(addressFamily<<4) | byte(network&0xF)

for _, tlv := range p.TLV {
buf = append(buf, byte(tlv.typ))
tlen := len(tlv.content)
buf = append(buf, byte(tlv.Typ))
tlen := len(tlv.Content)
buf = append(buf, byte(tlen>>8), byte(tlen))
buf = append(buf, tlv.content...)
buf = append(buf, tlv.Content...)
}

length := len(buf) - 4 - magicLen
Expand Down Expand Up @@ -205,8 +205,8 @@ func ParseProxyV2(rd io.Reader) (m *Proxy, n int, err error) {
length = len(buf) - 3
}
m.TLV = append(m.TLV, ProxyTlv{
typ: typ,
content: buf[3 : 3+length],
Typ: typ,
Content: buf[3 : 3+length],
})
buf = buf[3+length:]
}
Expand Down
14 changes: 7 additions & 7 deletions pkg/proxy/proxyprotocol/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ func TestProxyParse(t *testing.T) {
DstAddress: tcpaddr,
TLV: []ProxyTlv{
{
typ: ProxyTlvALPN,
content: nil,
Typ: ProxyTlvALPN,
Content: nil,
},
{
typ: ProxyTlvUniqueID,
content: []byte("test"),
Typ: ProxyTlvUniqueID,
Content: []byte("test"),
},
},
}
Expand All @@ -66,9 +66,9 @@ func TestProxyParse(t *testing.T) {
require.Equal(t, ProxyVersion2, p.Version)
require.Equal(t, ProxyCommandLocal, p.Command)
require.Len(t, p.TLV, 2)
require.Equal(t, ProxyTlvALPN, p.TLV[0].typ)
require.Equal(t, ProxyTlvUniqueID, p.TLV[1].typ)
require.Equal(t, []byte("test"), p.TLV[1].content)
require.Equal(t, ProxyTlvALPN, p.TLV[0].Typ)
require.Equal(t, ProxyTlvUniqueID, p.TLV[1].Typ)
require.Equal(t, []byte("test"), p.TLV[1].Content)
},
1,
)
Expand Down
8 changes: 6 additions & 2 deletions pkg/testkit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ func TestTCPConnWithListener(t *testing.T, listen func(*testing.T, string, strin
require.NoError(t, cli.SetDeadline(ddl))
}
a(t, cli)
require.NoError(t, cli.Close())
if err := cli.Close(); err != nil {
require.ErrorIs(t, err, net.ErrClosed)
}
})
wg.Run(func() {
srv, err := listener.Accept()
Expand All @@ -73,7 +75,9 @@ func TestTCPConnWithListener(t *testing.T, listen func(*testing.T, string, strin
require.NoError(t, srv.SetDeadline(ddl))
}
b(t, srv)
require.NoError(t, srv.Close())
if err := srv.Close(); err != nil {
require.ErrorIs(t, err, net.ErrClosed)
}
})
wg.Wait()
}
Expand Down

0 comments on commit d0e6292

Please sign in to comment.