Skip to content

Commit

Permalink
proxy: move proxyprotocol as a separate package
Browse files Browse the repository at this point in the history
Signed-off-by: xhe <xw897002528@gmail.com>
  • Loading branch information
xhebox committed Mar 13, 2023
1 parent c369737 commit 7903616
Show file tree
Hide file tree
Showing 10 changed files with 440 additions and 274 deletions.
7 changes: 2 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ EXECUTABLE_TARGETS := $(patsubst cmd/%,cmd_%,$(wildcard cmd/*))

default: cmd

dev: cmd lint test

cache: build lint test
dev: build lint test

cmd: $(EXECUTABLE_TARGETS)

Expand All @@ -55,8 +53,7 @@ gocovmerge:
tidy:
go mod tidy
cd lib && go mod tidy

cache:
build:
go build ./...
cd lib && go build ./...

Expand Down
7 changes: 4 additions & 3 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/pingcap/TiProxy/lib/util/errors"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/TiProxy/pkg/proxy/proxyprotocol"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/hack"
"go.uber.org/zap"
Expand Down Expand Up @@ -65,14 +66,14 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
if auth.proxyProtocol {
proxy := clientIO.Proxy()
if proxy == nil {
proxy = &pnet.Proxy{
proxy = &proxyprotocol.Proxy{
SrcAddress: clientIO.RemoteAddr(),
DstAddress: backendIO.RemoteAddr(),
Version: pnet.ProxyVersion2,
Version: proxyprotocol.ProxyVersion2,
}
}
// either from another proxy or directly from clients, we are acting as a proxy
proxy.Command = pnet.ProxyCommandProxy
proxy.Command = proxyprotocol.ProxyCommandProxy
if err := backendIO.WriteProxyV2(proxy); err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/pkg/proxy/keepalive"
"github.com/pingcap/TiProxy/pkg/proxy/proxyprotocol"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/dbterror"
Expand Down Expand Up @@ -84,7 +85,7 @@ type PacketIO struct {
rawConn net.Conn
buf *bufio.ReadWriter
proxyInited atomic.Bool
proxy *Proxy
proxy *proxyprotocol.Proxy
remoteAddr net.Addr
wrap error
sequence uint8
Expand Down Expand Up @@ -117,7 +118,7 @@ func (p *PacketIO) wrapErr(err error) error {
}

// Proxy returned parsed proxy header from clients if any.
func (p *PacketIO) Proxy() *Proxy {
func (p *PacketIO) Proxy() *proxyprotocol.Proxy {
if p.proxyInited.Load() {
return p.proxy
}
Expand Down
15 changes: 7 additions & 8 deletions pkg/proxy/net/packetio_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package net

import (
"net"

"github.com/pingcap/TiProxy/pkg/proxy/proxyprotocol"
)

type PacketIOption = func(*PacketIO)
Expand All @@ -31,13 +33,17 @@ func WithWrapError(err error) func(pi *PacketIO) {
}

// WithRemoteAddr
var _ net.Addr = &originAddr{}
var _ proxyprotocol.AddressWrapper = &originAddr{}

type originAddr struct {
net.Addr
addr string
}

func (o *originAddr) Unwrap() net.Addr {
return o.Addr
}

func (o *originAddr) String() string {
return o.addr
}
Expand All @@ -47,10 +53,3 @@ func WithRemoteAddr(readdr string, addr net.Addr) func(pi *PacketIO) {
pi.remoteAddr = &originAddr{Addr: addr, addr: readdr}
}
}

func unwrapOriginAddr(addr net.Addr) net.Addr {
if oaddr, ok := addr.(*originAddr); ok {
return oaddr.Addr
}
return addr
}
254 changes: 9 additions & 245 deletions pkg/proxy/net/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,147 +17,12 @@ package net
import (
"bytes"
"io"
"net"

"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/pkg/proxy/proxyprotocol"
)

var (
ErrAddressFamilyMismatch = errors.New("address family between source and target mismatched")
)

type ProxyVersion int

const (
ProxyVersion2 ProxyVersion = iota + 2
)

type ProxyCommand int

const (
ProxyCommandLocal ProxyCommand = iota
ProxyCommandProxy
)

type ProxyAddressFamily int

const (
ProxyAFUnspec ProxyAddressFamily = iota
ProxyAFINet
ProxyAFINet6
ProxyAFUnix
)

type ProxyNetwork int

const (
ProxyNetworkUnspec ProxyNetwork = iota
ProxyNetworkStream
ProxyNetworkDgram
)

type ProxyTlvType int

const (
ProxyTlvALPN ProxyTlvType = iota + 0x01
ProxyTlvAuthority
ProxyTlvCRC32C
ProxyTlvNoop
ProxyTlvUniqueID
ProxyTlvSSL ProxyTlvType = iota + 0x20
ProxyTlvSSLCN
ProxyTlvSSLCipher
ProxyTlvSSLSignALG
ProxyTlvSSLKeyALG
ProxyTlvNetns ProxyTlvType = iota + 0x30
)

type ProxyTlv struct {
content []byte
typ ProxyTlvType
}

type Proxy struct {
SrcAddress net.Addr
DstAddress net.Addr
TLV []ProxyTlv
Version ProxyVersion
Command ProxyCommand
}

func (p *Proxy) ToBytes() ([]byte, error) {
magicLen := len(proxyV2Magic)
buf := make([]byte, magicLen+4)
_ = copy(buf, proxyV2Magic)
buf[magicLen] = byte(p.Version<<4) | byte(p.Command&0xF)

addressFamily := ProxyAFUnspec
network := ProxyNetworkUnspec

srcAddr := unwrapOriginAddr(p.SrcAddress)
dstAddr := unwrapOriginAddr(p.DstAddress)

switch sadd := srcAddr.(type) {
case *net.TCPAddr:
addressFamily = ProxyAFINet
if len(sadd.IP) == net.IPv6len {
addressFamily = ProxyAFINet6
}
network = ProxyNetworkStream
dadd, ok := dstAddr.(*net.TCPAddr)
if !ok {
return nil, ErrAddressFamilyMismatch
}
buf = append(buf, sadd.IP...)
buf = append(buf, dadd.IP...)
buf = append(buf, byte(sadd.Port>>8), byte(sadd.Port))
buf = append(buf, byte(dadd.Port>>8), byte(dadd.Port))
case *net.UDPAddr:
addressFamily = ProxyAFINet
if len(sadd.IP) == net.IPv6len {
addressFamily = ProxyAFINet6
}
network = ProxyNetworkDgram
dadd, ok := dstAddr.(*net.UDPAddr)
if !ok {
return nil, ErrAddressFamilyMismatch
}
buf = append(buf, sadd.IP...)
buf = append(buf, dadd.IP...)
buf = append(buf, byte(sadd.Port>>8), byte(sadd.Port))
buf = append(buf, byte(dadd.Port>>8), byte(dadd.Port))
case *net.UnixAddr:
addressFamily = ProxyAFUnix
switch sadd.Net {
case "unix":
network = ProxyNetworkStream
case "unixdgram":
network = ProxyNetworkDgram
}
dadd, ok := dstAddr.(*net.UnixAddr)
if !ok {
return nil, ErrAddressFamilyMismatch
}
buf = append(buf, []byte(sadd.Name)...)
buf = append(buf, []byte(dadd.Name)...)
}
buf[magicLen+1] = byte(addressFamily<<4) | byte(network&0xF)

for _, tlv := range p.TLV {
buf = append(buf, byte(tlv.typ))
tlen := len(tlv.content)
buf = append(buf, byte(tlen>>8), byte(tlen))
buf = append(buf, tlv.content...)
}

length := len(buf) - 4 - magicLen
buf[magicLen+2] = byte(length >> 8)
buf[magicLen+3] = byte(length)

return buf, nil
}

func (p *PacketIO) parseProxyV2() (*Proxy, error) {
func (p *PacketIO) parseProxyV2() (*proxyprotocol.Proxy, error) {
rem, err := p.buf.Peek(8)
if err != nil {
return nil, errors.WithStack(errors.Wrap(ErrReadConn, err))
Expand All @@ -173,118 +38,17 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) {
}
p.inBytes += 8

var hdr [4]byte

if _, err := io.ReadFull(p.buf, hdr[:]); err != nil {
return nil, errors.WithStack(err)
}
p.inBytes += 4

m := &Proxy{}
m.Version = ProxyVersion(hdr[0] >> 4)
m.Command = ProxyCommand(hdr[0] & 0xF)

buf := make([]byte, int(hdr[2])<<8|int(hdr[3]))
if _, err := io.ReadFull(p.buf, buf); err != nil {
return nil, errors.WithStack(err)
m, n, err := proxyprotocol.ParseProxyV2(p.buf)
p.inBytes += uint64(n)
if err == nil {
// set RemoteAddr in case of proxy.
p.remoteAddr = m.SrcAddress
}
p.inBytes += uint64(len(buf))

addressFamily := ProxyAddressFamily(hdr[1] >> 4)
network := ProxyNetwork(hdr[1] & 0xF)
switch addressFamily {
case ProxyAFINet:
fallthrough
case ProxyAFINet6:
length := 4
if addressFamily == ProxyAFINet6 {
length = 16
}
if len(buf) < length*2+4 {
// TODO: logging
break
}
saddr := net.IP(buf[:length])
daddr := net.IP(buf[length : length*2])
sport := int(buf[2*length])<<8 | int(buf[2*length+1])
dport := int(buf[2*length+2])<<8 | int(buf[2*length+3])
switch network {
case ProxyNetworkStream:
m.SrcAddress = &net.TCPAddr{
IP: saddr,
Port: sport,
}
m.DstAddress = &net.TCPAddr{
IP: daddr,
Port: dport,
}
case ProxyNetworkDgram:
m.SrcAddress = &net.UDPAddr{
IP: saddr,
Port: sport,
}
m.DstAddress = &net.UDPAddr{
IP: daddr,
Port: dport,
}
default:
// TODO: logging
}
buf = buf[length*2+4:]
case ProxyAFUnix:
if len(buf) < 216 {
// TODO: logging
break
}
saddr := string(buf[:108])
daddr := string(buf[108:216])
switch network {
case ProxyNetworkStream:
m.SrcAddress = &net.UnixAddr{
Name: saddr,
Net: "unix",
}
m.DstAddress = &net.UnixAddr{
Name: daddr,
Net: "unix",
}
case ProxyNetworkDgram:
m.SrcAddress = &net.UnixAddr{
Name: saddr,
Net: "unixdgram",
}
m.DstAddress = &net.UnixAddr{
Name: daddr,
Net: "unixdgram",
}
default:
// TODO: logging
}
buf = buf[216:]
default:
buf = buf[len(buf):]
}

for len(buf) >= 3 {
typ := ProxyTlvType(buf[0])
length := int(buf[1])<<8 | int(buf[2])
if len(buf) < length+3 {
length = len(buf) - 3
}
m.TLV = append(m.TLV, ProxyTlv{
typ: typ,
content: buf[3 : 3+length],
})
buf = buf[3+length:]
}

// set RemoteAddr in case of proxy.
p.remoteAddr = m.SrcAddress
return m, nil
return m, err
}

// WriteProxyV2 should only be called at the beginning of connection, before any write operations.
func (p *PacketIO) WriteProxyV2(m *Proxy) error {
func (p *PacketIO) WriteProxyV2(m *proxyprotocol.Proxy) error {
buf, err := m.ToBytes()
if err != nil {
return errors.Wrap(ErrWriteConn, err)
Expand Down
Loading

0 comments on commit 7903616

Please sign in to comment.