Skip to content

Commit 4686efd

Browse files
committed
*: refine options
Signed-off-by: xhe <xw897002528@gmail.com>
1 parent e6738a1 commit 4686efd

File tree

12 files changed

+160
-89
lines changed

12 files changed

+160
-89
lines changed

conf/weirproxy.yaml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,36 @@ log:
2020
max-backups: 1
2121
security:
2222
rsa-key-size: 4096
23-
# tls object is either of type server or client
23+
# tls object is either of type server, client, or peer
2424
# xxxx:
2525
# ca: ca.pem
2626
# cert: c.pem
2727
# key: k.pem
2828
# auto-certs: true
2929
# skip-ca: trure
30-
# client-tls:
30+
# client object:
3131
# 1. requires: ca or skip-ca(skip verify server certs)
32-
# 2. optionally: cert/key will be used for server-side client verification
32+
# 2. optionally: cert/key will be used if server asks
3333
# 3. useless/forbid: auto-certs
3434
# server object:
3535
# 1. requires: cert/key or auto-certs(generate a temporary cert, mostly for testing)
36-
# 2. optionally: ca will enable server-side client verification
36+
# 2. optionally: ca will enable server-side client verification.
3737
# 3. useless/forbid: skip-ca
38+
# peer object:
39+
# 1. requires: cert/key/ca or auto-certs/skip-ca
3840
cluster-tls: # client object
3941
# access to other components like TiDB or PD, will use this
4042
skip-ca: true
4143
sql-tls: # client object
4244
# access to TiDB sql port, it has a standalone TLS configuration
4345
skip-ca: true
4446
server-tls: # server object
45-
# proxy SQL or internal HTTP port will all use this
47+
# proxy SQL or HTTP port will use this
4648
auto-certs: true
49+
peer-tls: # peer object
50+
# internal communication between proxies
51+
auto-certs: true
52+
skip-ca: true
4753
advance:
4854
# ignore-wrong-namespace: true
4955
# peer-port: "3081"

lib/config/proxy.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ func (c TLSConfig) HasCA() bool {
9494
type Security struct {
9595
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
9696
ServerTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"`
97+
PeerTLS TLSConfig `yaml:"peer-tls,omitempty" toml:"peer-tls,omitempty" json:"peer-tls,omitempty"`
9798
ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"`
9899
SQLTLS TLSConfig `yaml:"sql-tls,omitempty" toml:"sql-tls,omitempty" json:"sql-tls,omitempty"`
99100
}

lib/config/proxy_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ var testProxyConfig = Config{
6363
Key: "c",
6464
AutoCerts: true,
6565
},
66+
PeerTLS: TLSConfig{
67+
CA: "a",
68+
Cert: "b",
69+
Key: "c",
70+
AutoCerts: true,
71+
},
6672
ClusterTLS: TLSConfig{
6773
CA: "a",
6874
SkipCA: true,

lib/util/security/tls.go

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"crypto/x509"
2323
"crypto/x509/pkix"
2424
"encoding/pem"
25-
"fmt"
2625
"io/ioutil"
2726
"math/big"
2827
"net"
@@ -45,6 +44,13 @@ func createTLSConfigificates(logger *zap.Logger, certpath string, keypath string
4544
return errors.New("cert and key should be present or not at the same time")
4645
}
4746

47+
if err := os.MkdirAll(filepath.Dir(keypath), 0755); err != nil {
48+
return err
49+
}
50+
if err := os.MkdirAll(filepath.Dir(certpath), 0755); err != nil {
51+
return err
52+
}
53+
4854
privkey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
4955
if err != nil {
5056
return err
@@ -108,6 +114,18 @@ func createTLSConfigificates(logger *zap.Logger, certpath string, keypath string
108114
return nil
109115
}
110116

117+
func PreProcessTLSConfig(logger *zap.Logger, scfg *config.TLSConfig, workdir, mod string, keySize int) error {
118+
if !scfg.HasCert() && scfg.AutoCerts {
119+
scfg.Cert = filepath.Join(workdir, mod, "cert.pem")
120+
scfg.Key = filepath.Join(workdir, mod, "key.pem")
121+
if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, keySize); err != nil {
122+
return errors.WithStack(err)
123+
}
124+
return PreProcessTLSConfig(logger, scfg, workdir, mod, keySize)
125+
}
126+
return nil
127+
}
128+
111129
// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
112130
func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
113131
// set up our CA certificate
@@ -213,109 +231,98 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con
213231
return
214232
}
215233

216-
func BuildServerTLSConfig(logger *zap.Logger, cfg config.TLSConfig, workdir, mod string, keySize int) (*tls.Config, error) {
234+
func BuildServerTLSConfig(logger *zap.Logger, cfg config.TLSConfig) (*tls.Config, error) {
235+
logger = logger.With(zap.String("tls", "server"))
217236
if !cfg.HasCert() {
218-
if cfg.AutoCerts {
219-
cfg.Cert = filepath.Join(workdir, mod, "cert.pem")
220-
cfg.Key = filepath.Join(workdir, mod, "key.pem")
221-
if err := createTLSConfigificates(logger, cfg.Cert, cfg.Key, keySize); err != nil {
222-
return nil, err
223-
}
224-
return BuildServerTLSConfig(logger, cfg, workdir, mod, keySize)
225-
}
226-
227-
// TODO: require certs here
228-
logger.Warn(fmt.Sprintf("require certificates to secure %s clients connections", mod))
237+
logger.Warn("require certificates to secure clients connections, disable TLS")
229238
return nil, nil
230239
}
231240

232241
tcfg := &tls.Config{}
233242
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
234243
if err != nil {
235-
return nil, errors.Errorf("failed to load certs for %s: %w", mod, err)
244+
return nil, errors.Errorf("failed to load certs: %w", err)
236245
}
237246
tcfg.Certificates = append(tcfg.Certificates, cert)
238247

239248
if !cfg.HasCA() {
240-
logger.Warn(fmt.Sprintf("no signed certs for %s, will not authenticate %s clients (connection is still secured)", mod, mod))
249+
logger.Warn("no CA, server will not authenticate clients (connection is still secured)")
241250
return tcfg, nil
242251
}
243252

244253
tcfg.ClientAuth = tls.RequireAndVerifyClientCert
245254
tcfg.ClientCAs = x509.NewCertPool()
246255
certBytes, err := ioutil.ReadFile(cfg.CA)
247256
if err != nil {
248-
return nil, errors.Errorf("failed to read CA for %s: %w", mod, err)
257+
return nil, errors.Errorf("failed to read CA: %w", err)
249258
}
250259
if !tcfg.ClientCAs.AppendCertsFromPEM(certBytes) {
251-
return nil, errors.Errorf("failed to append CA for %s", mod)
260+
return nil, errors.Errorf("failed to append CA")
252261
}
253262
return tcfg, nil
254263
}
255264

256-
func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSConfig, mod string) (*tls.Config, error) {
265+
func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSConfig) (*tls.Config, error) {
266+
logger = logger.With(zap.String("tls", "client"))
257267
if !cfg.HasCA() {
258-
logger.Warn(fmt.Sprintf("require CA to verify %s server connections", mod))
259268
if cfg.SkipCA {
260269
// still enable TLS without verify server certs
261270
return &tls.Config{InsecureSkipVerify: true}, nil
262271
}
263-
// no TLS
272+
logger.Warn("no CA to verify server connections, disable TLS")
264273
return nil, nil
265274
}
266275

267276
tcfg := &tls.Config{}
268277
tcfg.ClientCAs = x509.NewCertPool()
269278
certBytes, err := ioutil.ReadFile(cfg.CA)
270279
if err != nil {
271-
return nil, errors.Errorf("failed to read CA for %s: %w", mod, err)
280+
return nil, errors.Errorf("failed to read CA: %w", err)
272281
}
273282
if !tcfg.ClientCAs.AppendCertsFromPEM(certBytes) {
274-
return nil, errors.Errorf("failed to append CA for %s", mod)
283+
return nil, errors.Errorf("failed to append CA")
275284
}
276285

277286
if !cfg.HasCert() {
278-
logger.Warn(fmt.Sprintf("no certs for %s, server may reject the connection", mod))
287+
logger.Warn("no certificates, server may reject the connection")
279288
return tcfg, nil
280289
}
281290
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
282291
if err != nil {
283-
return nil, errors.Errorf("failed to load certs for %s: %w", mod, err)
292+
return nil, errors.Errorf("failed to load certs for: %w", err)
284293
}
285294
tcfg.Certificates = append(tcfg.Certificates, cert)
286295

287296
return tcfg, nil
288297
}
289298

290-
func BuildEtcdTLSConfig(logger *zap.Logger, server config.TLSConfig, workdir, mod string, keySize int) (clientInfo, peerInfo transport.TLSInfo, err error) {
291-
if !server.HasCert() {
292-
if server.AutoCerts {
293-
server.Cert = filepath.Join(workdir, mod, "cert.pem")
294-
server.Key = filepath.Join(workdir, mod, "key.pem")
295-
if err = createTLSConfigificates(logger, server.Cert, server.Key, keySize); err != nil {
296-
return
297-
}
298-
return BuildEtcdTLSConfig(logger, server, workdir, mod, keySize)
299-
}
300-
} else {
299+
func BuildEtcdTLSConfig(logger *zap.Logger, server, peer config.TLSConfig) (clientInfo, peerInfo transport.TLSInfo, err error) {
300+
logger = logger.With(zap.String("tls", "etcd"))
301+
302+
if server.HasCert() {
301303
clientInfo.CertFile = server.Cert
302304
clientInfo.KeyFile = server.Key
303305
if server.HasCA() {
304306
clientInfo.ClientCertAuth = true
305307
clientInfo.TrustedCAFile = server.CA
306-
} else {
307-
logger.Warn("no signed certs for etcd clients, proxy will not authenticate etcd clients (connection is still secured)")
308+
} else if !server.SkipCA {
309+
logger.Warn("no CA, proxy will not authenticate etcd clients (connection is still secured)")
308310
}
309311
}
310312

311-
if server.HasCA() && server.HasCert() {
312-
peerInfo.CertFile = server.Cert
313-
peerInfo.KeyFile = server.Key
314-
peerInfo.TrustedCAFile = server.CA
315-
peerInfo.ClientCertAuth = true
316-
} else if server.HasCA() || server.HasCert() {
317-
err = errors.New("need a full set of cert/ca/key for secure etcd peer inter-communication")
318-
return
313+
if peer.HasCert() {
314+
peerInfo.CertFile = peer.Cert
315+
peerInfo.KeyFile = peer.Key
316+
if peer.HasCA() {
317+
peerInfo.TrustedCAFile = peer.CA
318+
peerInfo.ClientCertAuth = true
319+
} else if peer.SkipCA {
320+
peerInfo.InsecureSkipVerify = true
321+
peerInfo.ClientCertAuth = false
322+
} else {
323+
err = errors.New("need a full set of cert/key/ca or cert/key/skip-ca for secure etcd peer inter-communication")
324+
return
325+
}
319326
}
320327

321328
return

pkg/manager/config/manager.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (srv *ConfigManager) Init(ctx context.Context, addrs []string, cfg config.A
8383
}
8484

8585
var err error
86-
etcdConfig.TLS, err = security.BuildClientTLSConfig(logger, scfg, "frontend")
86+
etcdConfig.TLS, err = security.BuildClientTLSConfig(logger, scfg)
8787
if err != nil {
8888
return errors.Wrapf(err, "create etcd config center error")
8989
}

pkg/manager/namespace/manager.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ package namespace
1717

1818
import (
1919
"fmt"
20-
"path/filepath"
20+
"net/http"
2121
"sync"
2222

2323
"github.com/pingcap/TiProxy/lib/config"
@@ -31,18 +31,18 @@ import (
3131
type NamespaceManager struct {
3232
sync.RWMutex
3333
client *clientv3.Client
34+
httpCli *http.Client
3435
logger *zap.Logger
35-
workdir string
36-
keySize int
3736
nsm map[string]*Namespace
3837
}
3938

40-
func NewNamespaceManager(workdir string, keySize int) *NamespaceManager {
41-
return &NamespaceManager{workdir: workdir, keySize: keySize}
39+
func NewNamespaceManager() *NamespaceManager {
40+
return &NamespaceManager{}
4241
}
43-
func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace, client *clientv3.Client) (*Namespace, error) {
42+
func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace, error) {
4443
logger := mgr.logger.With(zap.String("namespace", cfg.Namespace))
45-
rt, err := router.NewScoreBasedRouter(&cfg.Backend, client)
44+
45+
rt, err := router.NewScoreBasedRouter(&cfg.Backend, mgr.client, mgr.httpCli)
4646
if err != nil {
4747
return nil, errors.Errorf("build router error: %w", err)
4848
}
@@ -51,12 +51,12 @@ func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace, client *clien
5151
router: rt,
5252
}
5353

54-
r.frontendTLS, err = security.BuildServerTLSConfig(logger, cfg.Frontend.Security, filepath.Join(mgr.workdir, r.name), "frontend", mgr.keySize)
54+
r.frontendTLS, err = security.BuildServerTLSConfig(logger, cfg.Frontend.Security)
5555
if err != nil {
5656
return nil, errors.Errorf("build router error: %w", err)
5757
}
5858

59-
r.backendTLS, err = security.BuildClientTLSConfig(logger, cfg.Backend.Security, "backend")
59+
r.backendTLS, err = security.BuildClientTLSConfig(logger, cfg.Backend.Security)
6060
if err != nil {
6161
return nil, errors.Errorf("build router error: %w", err)
6262
}
@@ -78,7 +78,7 @@ func (mgr *NamespaceManager) CommitNamespaces(nss []*config.Namespace, nss_delet
7878
continue
7979
}
8080

81-
ns, err := mgr.buildNamespace(nsc, mgr.client)
81+
ns, err := mgr.buildNamespace(nsc)
8282
if err != nil {
8383
return fmt.Errorf("%w: create namespace error, namespace: %s", err, nsc.Namespace)
8484
}
@@ -91,9 +91,10 @@ func (mgr *NamespaceManager) CommitNamespaces(nss []*config.Namespace, nss_delet
9191
return nil
9292
}
9393

94-
func (mgr *NamespaceManager) Init(logger *zap.Logger, nss []*config.Namespace, client *clientv3.Client) error {
94+
func (mgr *NamespaceManager) Init(logger *zap.Logger, nss []*config.Namespace, client *clientv3.Client, httpCli *http.Client) error {
9595
mgr.Lock()
9696
mgr.client = client
97+
mgr.httpCli = httpCli
9798
mgr.logger = logger
9899
mgr.Unlock()
99100

pkg/manager/router/backend_observer.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ type BackendObserver struct {
132132
// All the backend info in the topology, including tombstones.
133133
allBackendInfo map[string]*BackendInfo
134134
client *clientv3.Client
135+
httpCli *http.Client
136+
httpTLS bool
135137
staticAddrs []string
136138
eventReceiver BackendEventReceiver
137139
wg waitgroup.WaitGroup
@@ -149,7 +151,7 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, e
149151
pdEndpoints := strings.Split(pdAddr, ",")
150152
logConfig := zap.NewProductionConfig()
151153
logConfig.Level = zap.NewAtomicLevelAt(zap.ErrorLevel)
152-
tlsConfig, err := security.BuildClientTLSConfig(logger, cfg.Security.ClusterTLS, "pd")
154+
tlsConfig, err := security.BuildClientTLSConfig(logger, cfg.Security.ClusterTLS)
153155
if err != nil {
154156
return nil, err
155157
}
@@ -182,8 +184,8 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, e
182184
}
183185

184186
// StartBackendObserver creates a BackendObserver and starts watching.
185-
func StartBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) {
186-
bo, err := NewBackendObserver(eventReceiver, client, config, staticAddrs)
187+
func StartBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, httpCli *http.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) {
188+
bo, err := NewBackendObserver(eventReceiver, client, httpCli, config, staticAddrs)
187189
if err != nil {
188190
return nil, err
189191
}
@@ -192,15 +194,24 @@ func StartBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.C
192194
}
193195

194196
// NewBackendObserver creates a BackendObserver.
195-
func NewBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) {
197+
func NewBackendObserver(eventReceiver BackendEventReceiver, client *clientv3.Client, httpCli *http.Client, config *HealthCheckConfig, staticAddrs []string) (*BackendObserver, error) {
196198
if client == nil && len(staticAddrs) == 0 {
197199
return nil, ErrNoInstanceToSelect
198200
}
201+
if httpCli == nil {
202+
httpCli = http.DefaultClient
203+
}
204+
httpTLS := false
205+
if v, ok := httpCli.Transport.(*http.Transport); ok && v != nil && v.TLSClientConfig != nil {
206+
httpTLS = true
207+
}
199208
bo := &BackendObserver{
200209
config: config,
201210
curBackendInfo: make(map[string]BackendStatus),
202211
allBackendInfo: make(map[string]*BackendInfo),
203212
client: client,
213+
httpCli: httpCli,
214+
httpTLS: httpTLS,
204215
staticAddrs: staticAddrs,
205216
eventReceiver: eventReceiver,
206217
}
@@ -382,11 +393,15 @@ func (bo *BackendObserver) checkHealth(ctx context.Context, backends map[string]
382393

383394
// When a backend gracefully shut down, the status port returns 500 but the SQL port still accepts
384395
// new connections, so we must check the status port first.
385-
url := fmt.Sprintf("http://%s:%d%s", info.IP, info.StatusPort, statusPathSuffix)
396+
schema := "http"
397+
if bo.httpTLS {
398+
schema = "https"
399+
}
400+
url := fmt.Sprintf("%s://%s:%d%s", schema, info.IP, info.StatusPort, statusPathSuffix)
386401
var resp *http.Response
387402
err := connectWithRetry(func() error {
388403
var err error
389-
if resp, err = http.Get(url); err == nil {
404+
if resp, err = bo.httpCli.Get(url); err == nil {
390405
if err := resp.Body.Close(); err != nil {
391406
logutil.Logger(ctx).Warn("close http response in health check failed", zap.Error(err))
392407
}

0 commit comments

Comments
 (0)