Skip to content

Commit 1d40ba8

Browse files
progress save on certreloader + working local tests
Signed-off-by: Kevin Schoonover <me@kschoon.me>
1 parent febbbd3 commit 1d40ba8

File tree

6 files changed

+430
-58
lines changed

6 files changed

+430
-58
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
FROM squidfunk/mkdocs-material:9.5
2-
RUN pip install mkdocs-include-markdown-plugin
2+
RUN pip install mkdocs-include-markdown-plugin

core/pkg/telemetry/builder.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"connectrpc.com/connect"
1212
"connectrpc.com/otelconnect"
1313
"github.com/open-feature/flagd/core/pkg/logger"
14+
"github.com/open-feature/flagd/flagd/pkg/certreloader"
1415
"go.opentelemetry.io/otel"
1516
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
1617
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
@@ -33,10 +34,11 @@ const (
3334
)
3435

3536
type CollectorConfig struct {
36-
Target string
37-
CertPath string
38-
KeyPath string
39-
CAPath string
37+
Target string
38+
CertPath string
39+
KeyPath string
40+
ReloadInterval time.Duration
41+
CAPath string
4042
}
4143

4244
// Config of the telemetry runtime. These are expected to be mapped to start-up arguments
@@ -132,15 +134,20 @@ func buildTransportCredentials(_ context.Context, cfg CollectorConfig) (credenti
132134
}
133135
}
134136

137+
reloader, err := certreloader.NewCertReloader(certreloader.Config{
138+
KeyPath: cfg.KeyPath,
139+
CertPath: cfg.CertPath,
140+
ReloadInterval: time.Minute * 5,
141+
})
142+
if err != nil {
143+
return nil, fmt.Errorf("failed to create certreloader: %w", err)
144+
}
145+
135146
tlsConfig := &tls.Config{
136-
RootCAs: capool,
147+
RootCAs: capool,
148+
MinVersion: tls.VersionTLS13,
137149
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
138-
newCert, err := tls.LoadX509KeyPair(cfg.CertPath, cfg.KeyPath)
139-
if err != nil {
140-
return nil, err
141-
}
142-
143-
return &newCert, err
150+
return reloader.GetCertificate()
144151
},
145152
}
146153

flagd/cmd/start.go

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"log"
66
"strings"
7+
"time"
78

89
"github.com/open-feature/flagd/core/pkg/logger"
910
"github.com/open-feature/flagd/core/pkg/sync"
@@ -16,22 +17,23 @@ import (
1617
)
1718

1819
const (
19-
corsFlagName = "cors-origin"
20-
logFormatFlagName = "log-format"
21-
managementPortFlagName = "management-port"
22-
metricsExporter = "metrics-exporter"
23-
ofrepPortFlagName = "ofrep-port"
24-
otelCollectorURI = "otel-collector-uri"
25-
otelCertPathFlagName = "otel-cert-path"
26-
otelKeyPathFlagName = "otel-key-path"
27-
otelCAPathFlagName = "otel-ca-path"
28-
portFlagName = "port"
29-
serverCertPathFlagName = "server-cert-path"
30-
serverKeyPathFlagName = "server-key-path"
31-
socketPathFlagName = "socket-path"
32-
sourcesFlagName = "sources"
33-
syncPortFlagName = "sync-port"
34-
uriFlagName = "uri"
20+
corsFlagName = "cors-origin"
21+
logFormatFlagName = "log-format"
22+
managementPortFlagName = "management-port"
23+
metricsExporter = "metrics-exporter"
24+
ofrepPortFlagName = "ofrep-port"
25+
otelCollectorURI = "otel-collector-uri"
26+
otelCertPathFlagName = "otel-cert-path"
27+
otelKeyPathFlagName = "otel-key-path"
28+
otelCAPathFlagName = "otel-ca-path"
29+
otelReloadIntervalFlagName = "otel-reload-interval"
30+
portFlagName = "port"
31+
serverCertPathFlagName = "server-cert-path"
32+
serverKeyPathFlagName = "server-key-path"
33+
socketPathFlagName = "socket-path"
34+
sourcesFlagName = "sources"
35+
syncPortFlagName = "sync-port"
36+
uriFlagName = "uri"
3537
)
3638

3739
func init() {
@@ -73,6 +75,7 @@ func init() {
7375
flags.StringP(otelCertPathFlagName, "D", "", "tls certificate path to use with OpenTelemetry collector")
7476
flags.StringP(otelKeyPathFlagName, "K", "", "tls key path to use with OpenTelemetry collector")
7577
flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector")
78+
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate from disk")
7679

7780
_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
7881
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
@@ -136,20 +139,21 @@ var startCmd = &cobra.Command{
136139

137140
// Build Runtime -----------------------------------------------------------
138141
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
139-
CORS: viper.GetStringSlice(corsFlagName),
140-
MetricExporter: viper.GetString(metricsExporter),
141-
ManagementPort: viper.GetUint16(managementPortFlagName),
142-
OfrepServicePort: viper.GetUint16(ofrepPortFlagName),
143-
OtelCollectorURI: viper.GetString(otelCollectorURI),
144-
OtelCertPath: viper.GetString(otelCertPathFlagName),
145-
OtelKeyPath: viper.GetString(otelKeyPathFlagName),
146-
OtelCAPath: viper.GetString(otelCAPathFlagName),
147-
ServiceCertPath: viper.GetString(serverCertPathFlagName),
148-
ServiceKeyPath: viper.GetString(serverKeyPathFlagName),
149-
ServicePort: viper.GetUint16(portFlagName),
150-
ServiceSocketPath: viper.GetString(socketPathFlagName),
151-
SyncServicePort: viper.GetUint16(syncPortFlagName),
152-
SyncProviders: syncProviders,
142+
CORS: viper.GetStringSlice(corsFlagName),
143+
MetricExporter: viper.GetString(metricsExporter),
144+
ManagementPort: viper.GetUint16(managementPortFlagName),
145+
OfrepServicePort: viper.GetUint16(ofrepPortFlagName),
146+
OtelCollectorURI: viper.GetString(otelCollectorURI),
147+
OtelCertPath: viper.GetString(otelCertPathFlagName),
148+
OtelKeyPath: viper.GetString(otelKeyPathFlagName),
149+
OtelReloadInterval: viper.GetDuration(otelReloadIntervalFlagName),
150+
OtelCAPath: viper.GetString(otelCAPathFlagName),
151+
ServiceCertPath: viper.GetString(serverCertPathFlagName),
152+
ServiceKeyPath: viper.GetString(serverKeyPathFlagName),
153+
ServicePort: viper.GetUint16(portFlagName),
154+
ServiceSocketPath: viper.GetString(socketPathFlagName),
155+
SyncServicePort: viper.GetUint16(syncPortFlagName),
156+
SyncProviders: syncProviders,
153157
})
154158
if err != nil {
155159
rtLogger.Fatal(err.Error())
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package certreloader
2+
3+
import (
4+
"crypto/tls"
5+
"fmt"
6+
"sync"
7+
"time"
8+
)
9+
10+
type Config struct {
11+
KeyPath string
12+
CertPath string
13+
ReloadInterval time.Duration
14+
}
15+
16+
type certReloader struct {
17+
cert *tls.Certificate
18+
mu sync.RWMutex
19+
nextReload time.Time
20+
Config
21+
}
22+
23+
func NewCertReloader(config Config) (*certReloader, error) {
24+
reloader := certReloader{
25+
Config: config,
26+
}
27+
28+
reloader.mu.Lock()
29+
defer reloader.mu.Unlock()
30+
cert, err := reloader.loadCertificate()
31+
if err != nil {
32+
return nil, fmt.Errorf("failed to load initial certificate: %w", err)
33+
}
34+
reloader.cert = &cert
35+
36+
return &reloader, nil
37+
}
38+
39+
func (r *certReloader) GetCertificate() (*tls.Certificate, error) {
40+
now := time.Now()
41+
// Read locking here before we do the time comparison
42+
// If a reload is in progress this will block and we will skip reloading in the current
43+
// call once we can continue
44+
r.mu.RLock()
45+
shouldReload := r.ReloadInterval != 0 && r.nextReload.Before(now)
46+
r.mu.RUnlock()
47+
if shouldReload {
48+
// Need to release the read lock, otherwise we deadlock
49+
r.mu.Lock()
50+
defer r.mu.Unlock()
51+
cert, err := r.loadCertificate()
52+
if err != nil {
53+
return nil, fmt.Errorf("failed to load TLS cert and key: %w", err)
54+
}
55+
r.cert = &cert
56+
r.nextReload = now.Add(r.ReloadInterval)
57+
return r.cert, nil
58+
}
59+
return r.cert, nil
60+
}
61+
62+
func (c *certReloader) loadCertificate() (tls.Certificate, error) {
63+
newCert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath)
64+
if err != nil {
65+
return tls.Certificate{}, fmt.Errorf("failed to load key pair: %w", err)
66+
}
67+
68+
return newCert, nil
69+
}

0 commit comments

Comments
 (0)