-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
cert_watcher.go
140 lines (122 loc) · 3.71 KB
/
cert_watcher.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright (c) 2020 The Jaeger Authors.
// SPDX-License-Identifier: Apache-2.0
package tlscfg
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"path/filepath"
"sync"
"go.uber.org/zap"
"github.com/jaegertracing/jaeger/pkg/fswatcher"
)
// Log messages emitted when a watched file changes on disk. The *Reloaded
// messages are logged at Info level after a successful reload; the
// *NotReloaded messages are logged at Error level when a reload fails.
const (
	// Outcomes of reloading the certificate/key pair (onCertPairChange).
	logMsgPairReloaded    = "Reloaded modified key pair"
	logMsgPairNotReloaded = "Failed to reload key pair, using previous versions"

	// Outcomes of reloading a CA certificate into a pool (onCertChange).
	logMsgCertReloaded    = "Reloaded modified certificate"
	logMsgCertNotReloaded = "Failed to reload certificate, using previous version"
)
// certWatcher watches filesystem changes on the certificate, key and CA files supplied via Options.
// Changed RootCAs and ClientCAs certificates are added to the existing x509.CertPool instead of
// replacing the pool, so previously added certificates are not invalidated.
// NOTE(review): the exact pool-update semantics live in addCertToPool (not in this file) — confirm.
// The current certificate/key pair can be obtained via certWatcher.certificate.
// Consumers of this API should supply the pair to tls.Config through its GetCertificate or
// GetClientCertificate callbacks, so that each handshake sees the most recently loaded pair.
type certWatcher struct {
// mu guards cert (RLock in certificate, Lock in onCertPairChange) and also serializes
// CA pool updates in onCertChange.
// NOTE(review): mu does not cover reads of the CA pools performed by crypto/tls during
// handshakes; verify that mutating those pools in place is safe while handshakes run.
mu sync.RWMutex
// opts supplies the watched paths (CertPath, KeyPath, CAPath, ClientCAPath).
opts Options
logger *zap.Logger
// watchers holds every fswatcher successfully created for this instance; Close closes them all.
watchers []*fswatcher.FSWatcher
// cert is the most recently loaded key pair; nil when no pair has been loaded.
cert *tls.Certificate
}
// Compile-time assertion that *certWatcher implements io.Closer.
var _ io.Closer = (*certWatcher)(nil)
// newCertWatcher creates a certWatcher for opts.
//
// When both opts.CertPath and opts.KeyPath are set, the key pair is loaded
// immediately, so a missing or invalid pair is reported at startup rather than
// on the first reload. The function then installs watchers for the key pair,
// for opts.CAPath (reloading into rootCAs) and for opts.ClientCAPath
// (reloading into clientCAs). On any failure it returns a nil watcher and the
// error. The failing watch helper has already closed the watchers created
// before it.
func newCertWatcher(opts Options, logger *zap.Logger, rootCAs, clientCAs *x509.CertPool) (*certWatcher, error) {
	w := &certWatcher{opts: opts, logger: logger}

	if opts.CertPath != "" && opts.KeyPath != "" {
		// Fail fast on a bad pair instead of waiting for the first file change.
		pair, err := tls.LoadX509KeyPair(filepath.Clean(opts.CertPath), filepath.Clean(opts.KeyPath))
		if err != nil {
			return nil, fmt.Errorf("failed to load server TLS cert and key: %w", err)
		}
		w.cert = &pair
	}

	if err := w.watchCertPair(); err != nil {
		return nil, err
	}

	caWatches := []struct {
		path string
		pool *x509.CertPool
	}{
		{path: opts.CAPath, pool: rootCAs},
		{path: opts.ClientCAPath, pool: clientCAs},
	}
	for _, ca := range caWatches {
		if err := w.watchCert(ca.path, ca.pool); err != nil {
			return nil, err
		}
	}

	return w, nil
}
// Close stops every filesystem watcher owned by w. It returns the errors from
// closing them, joined with errors.Join. The result is nil when every watcher
// closed cleanly or when there were none.
func (w *certWatcher) Close() error {
	errs := make([]error, 0, len(w.watchers))
	// The loop variable is fsw, not w, so it does not shadow the receiver.
	for _, fsw := range w.watchers {
		errs = append(errs, fsw.Close())
	}
	return errors.Join(errs...)
}
// certificate returns the most recently loaded key pair. The result is nil if
// newCertWatcher did not load one and no reload has succeeded since. The read
// happens under the read lock, so it is safe to call while onCertPairChange
// swaps in a new pair.
func (w *certWatcher) certificate() *tls.Certificate {
	w.mu.RLock()
	current := w.cert
	w.mu.RUnlock()
	return current
}
// watchCertPair registers a filesystem watcher on the configured certificate
// and key files; a change to either triggers onCertPairChange.
//
// On success the watcher is recorded in w.watchers so Close can release it.
// On failure every watcher created so far is closed. The returned error wraps
// the watch failure and, if that cleanup also fails, the cleanup error
// (previously it was silently dropped).
func (w *certWatcher) watchCertPair() error {
	watcher, err := fswatcher.New(
		[]string{w.opts.CertPath, w.opts.KeyPath},
		w.onCertPairChange,
		w.logger,
	)
	if err != nil {
		watchErr := fmt.Errorf("failed to watch key pair %s and %s: %w", w.opts.KeyPath, w.opts.CertPath, err)
		// errors.Join drops a nil cleanup error, so the message is unchanged when Close succeeds.
		return errors.Join(watchErr, w.Close())
	}
	w.watchers = append(w.watchers, watcher)
	return nil
}
// watchCert registers a filesystem watcher on certPath. A change to the file
// calls onCertChange, which loads it into certPool.
//
// On success the watcher is recorded in w.watchers so Close can release it.
// On failure every watcher created so far is closed. The returned error wraps
// the watch failure and, if that cleanup also fails, the cleanup error
// (previously it was silently dropped).
func (w *certWatcher) watchCert(certPath string, certPool *x509.CertPool) error {
	onCertChange := func() { w.onCertChange(certPath, certPool) }
	watcher, err := fswatcher.New([]string{certPath}, onCertChange, w.logger)
	if err != nil {
		watchErr := fmt.Errorf("failed to watch cert %s: %w", certPath, err)
		// errors.Join drops a nil cleanup error, so the message is unchanged when Close succeeds.
		return errors.Join(watchErr, w.Close())
	}
	w.watchers = append(w.watchers, watcher)
	return nil
}
// onCertPairChange is the change callback for the certificate/key files.
// It re-reads both files. If loading fails, it logs an error and leaves the
// previously loaded pair in place. If loading succeeds, it swaps the new pair
// into w.cert under the write lock and logs the reload.
func (w *certWatcher) onCertPairChange() {
	certPath, keyPath := w.opts.CertPath, w.opts.KeyPath

	pair, err := tls.LoadX509KeyPair(filepath.Clean(certPath), filepath.Clean(keyPath))
	if err != nil {
		w.logger.Error(
			logMsgPairNotReloaded,
			zap.String("key", keyPath),
			zap.String("cert", certPath),
			zap.Error(err),
		)
		return
	}

	w.mu.Lock()
	w.cert = &pair
	w.mu.Unlock()

	w.logger.Info(
		logMsgPairReloaded,
		zap.String("key", keyPath),
		zap.String("cert", certPath),
	)
}
// onCertChange is the change callback installed by watchCert. It loads the
// certificate at certPath into certPool and logs the outcome; a failure is
// logged at Error level and is not returned.
//
// The pool update runs under w.mu so concurrent callbacks cannot update the
// same certPool at once. The unlock is deferred, so a panic inside
// addCertToPool cannot leave w.mu held forever. Logging happens after the lock
// is released, which keeps certificate() readers of the same mutex blocked for
// as short a time as possible.
// NOTE(review): w.mu does not cover crypto/tls reading certPool during
// handshakes; confirm that in-place mutation of the pool is acceptable.
func (w *certWatcher) onCertChange(certPath string, certPool *x509.CertPool) {
	err := func() error {
		w.mu.Lock() // prevent concurrent updates to the same certPool
		defer w.mu.Unlock()
		return addCertToPool(certPath, certPool)
	}()
	if err != nil {
		w.logger.Error(logMsgCertNotReloaded, zap.String("cert", certPath), zap.Error(err))
		return
	}
	w.logger.Info(logMsgCertReloaded, zap.String("cert", certPath))
}