Skip to content

Commit 8f31766

Browse files
committed
Adds an option to serve and verify regular TLS
1 parent 8ff4c2f commit 8f31766

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

cmd/proxy-client/main.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
package main
22

33
import (
4+
"errors"
45
"log"
56
"net/http"
67
"os"
78

89
"github.com/flashbots/cvm-reverse-proxy/common"
910
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
1011
"github.com/flashbots/cvm-reverse-proxy/proxy"
11-
1212
"github.com/urfave/cli/v2" // imports as package "cli"
1313
)
1414

@@ -28,6 +28,11 @@ var flags []cli.Flag = []cli.Flag{
2828
Value: string(proxy.AttestationAzureTDX),
2929
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
3030
},
31+
&cli.BoolFlag{
32+
Name: "verify-tls",
33+
Value: false,
34+
Usage: "verify server's TLS certificate instead of server's attestation. Only valid for server-attestation-type=none.",
35+
},
3136
&cli.StringFlag{
3237
Name: "server-measurements",
3338
Usage: "optional path to JSON measurements enforced on the server",
@@ -76,6 +81,11 @@ func runClient(cCtx *cli.Context) error {
7681
Version: common.Version,
7782
})
7883

84+
if cCtx.String("server-attestation-type") != "none" && cCtx.Bool("verify-tls") {
85+
log.Error("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
86+
return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
87+
}
88+
7989
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
8090
if err != nil {
8191
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
@@ -106,6 +116,11 @@ func runClient(cCtx *cli.Context) error {
106116
return err
107117
}
108118

119+
if cCtx.Bool("verify-tls") {
120+
tlsConfig.InsecureSkipVerify = false
121+
tlsConfig.VerifyPeerCertificate = nil // TODO: make sure this is needed
122+
}
123+
109124
proxyHandler := proxy.NewProxy(targetAddr, validators).WithTransport(&http.Transport{TLSClientConfig: tlsConfig})
110125

111126
log.With("listenAddr", listenAddr).Info("about to start proxy")

cmd/proxy-server/main.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"context"
55
"crypto/tls"
6+
"errors"
67
"log"
78
"net/http"
89
"os"
@@ -13,7 +14,6 @@ import (
1314
"github.com/flashbots/cvm-reverse-proxy/common"
1415
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
1516
"github.com/flashbots/cvm-reverse-proxy/proxy"
16-
1717
"github.com/urfave/cli/v2" // imports as package "cli"
1818
)
1919

@@ -38,6 +38,14 @@ var flags []cli.Flag = []cli.Flag{
3838
Value: string(proxy.AttestationNone),
3939
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
4040
},
41+
&cli.StringFlag{
42+
Name: "tls-certificate",
43+
Usage: "Certificate to present (PEM)",
44+
},
45+
&cli.StringFlag{
46+
Name: "tls-private-key",
47+
Usage: "Private key for the certificate (PEM)",
48+
},
4149
&cli.StringFlag{
4250
Name: "client-measurements",
4351
Usage: "optional path to JSON measurements enforced on the client",
@@ -73,6 +81,10 @@ func runServer(cCtx *cli.Context) error {
7381
clientMeasurements := cCtx.String("client-measurements")
7482
logJSON := cCtx.Bool("log-json")
7583
logDebug := cCtx.Bool("log-debug")
84+
serverAttestationTypeFlag := cCtx.String("server-attestation-type")
85+
86+
certFile := cCtx.String("tls-certificate")
87+
keyFile := cCtx.String("tls-private-key")
7688

7789
log := common.SetupLogger(&common.LoggingOpts{
7890
Debug: logDebug,
@@ -81,7 +93,18 @@ func runServer(cCtx *cli.Context) error {
8193
Version: common.Version,
8294
})
8395

84-
serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type"))
96+
useRegularTLS := certFile != "" || keyFile != ""
97+
if serverAttestationTypeFlag != "none" && useRegularTLS {
98+
log.Error("invalid combination of --tls-certificate, --tls-private-key and --server-attestation-type flags passed (only 'none' is allowed)")
99+
return errors.New("invalid combination of --tls-certificate, --tls-private-key and --server-attestation-type flags passed (only 'none' is allowed)")
100+
}
101+
102+
if useRegularTLS && (certFile == "" || keyFile == "") {
103+
log.Error("not all of --tls-certificate and --tls-private-key specified")
104+
return errors.New("not all of --tls-certificate and --tls-private-key specified")
105+
}
106+
107+
serverAttestationType, err := proxy.ParseAttestationType(serverAttestationTypeFlag)
85108
if err != nil {
86109
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
87110
return err
@@ -112,6 +135,29 @@ func runServer(cCtx *cli.Context) error {
112135
panic(err)
113136
}
114137

138+
if useRegularTLS {
139+
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
140+
if err != nil {
141+
log.Error("could not load tls key pair", "err", err)
142+
return err
143+
}
144+
145+
originalGetConfigForClient := confTLS.GetConfigForClient
146+
confTLS.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
147+
ogClientConfig, err := originalGetConfigForClient(clientHello)
148+
if err != nil {
149+
return ogClientConfig, err
150+
}
151+
152+
// Note: we don't have to copy the certificate because it's always created per request
153+
ogClientConfig.Certificates = []tls.Certificate{cert}
154+
ogClientConfig.GetClientCertificate = nil
155+
ogClientConfig.ServerName = ""
156+
return ogClientConfig, nil
157+
}
158+
159+
}
160+
115161
// Create an HTTP server
116162
server := &http.Server{
117163
Addr: listenAddr,

0 commit comments

Comments
 (0)