Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 95 additions & 40 deletions cmd/proxy-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ import (
"crypto/x509"
"errors"
"log"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"os"

"github.com/flashbots/cvm-reverse-proxy/common"
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
"github.com/flashbots/cvm-reverse-proxy/multimeasurements"
"github.com/flashbots/cvm-reverse-proxy/proxy"
"github.com/flashbots/cvm-reverse-proxy/tdx"
"github.com/urfave/cli/v2" // imports as package "cli"
Expand Down Expand Up @@ -65,6 +69,15 @@ var flags []cli.Flag = []cli.Flag{
},
}

type ClientConfig struct {
ListenAddr string `json:"listen_addr"`
TargetAddr string `json:"target_addr"`
ServerMeasurements string `json:"server_measurements"`
VerifyTLS bool `json:"verify_tls"`
TLSCACertificate string `json:"tls_ca_certificate"`
ClientAttestationType string `json:"client_attestation_type"`
}

func main() {
app := &cli.App{
Name: "proxy-client",
Expand All @@ -79,18 +92,18 @@ func main() {
}

func runClient(cCtx *cli.Context) error {
listenAddr := cCtx.String("listen-addr")
targetAddr := cCtx.String("target-addr")
serverMeasurements := cCtx.String("server-measurements")
logJSON := cCtx.Bool("log-json")
logDebug := cCtx.Bool("log-debug")
tdx.SetLogDcapQuote(cCtx.Bool("log-dcap-quote"))

verifyTLS := cCtx.Bool("verify-tls")
config := ClientConfig{
ListenAddr: cCtx.String("listen-addr"),
TargetAddr: cCtx.String("target-addr"),
ServerMeasurements: cCtx.String("server-measurements"),
VerifyTLS: cCtx.Bool("verify-tls"),
TLSCACertificate: cCtx.String("tls-ca-certificate"),
ClientAttestationType: cCtx.String("client-attestation-type"),
}

log := common.SetupLogger(&common.LoggingOpts{
Debug: logDebug,
JSON: logJSON,
Debug: cCtx.Bool("log-debug"),
JSON: cCtx.Bool("log-json"),
Service: "proxy-client",
Version: common.Version,
})
Expand All @@ -99,14 +112,20 @@ func runClient(cCtx *cli.Context) error {
log.Warn("DEPRECATED: --server-attestation-type is deprecated and will be removed in a future version")
}

if serverMeasurements != "" && verifyTLS {
tdx.SetLogDcapQuote(cCtx.Bool("log-dcap-quote"))

return runClientFromConfig(log, config)
}

func runClientFromConfig(log *slog.Logger, config ClientConfig) error {
if config.ServerMeasurements != "" && config.VerifyTLS {
log.Error("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
return errors.New("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
}

clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
clientAttestationType, err := proxy.ParseAttestationType(config.ClientAttestationType)
if err != nil {
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
log.With("attestation-type", config.ClientAttestationType).Error("invalid client-attestation-type passed, see --help")
return err
}

Expand All @@ -116,49 +135,85 @@ func runClient(cCtx *cli.Context) error {
return err
}

validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements)
parsedMeasurements, err := proxy.LoadMeasurementsFromFile(log, config.ServerMeasurements)
if err != nil {
log.Error("could not create attestation validators from file", "err", err)
return err
}

tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, validators)
if err != nil {
log.Error("could not create atls config", "err", err)
return err
}

if verifyTLS {
tlsConfig.InsecureSkipVerify = false
tlsConfig.ServerName = ""
}

if additionalTLSCA := cCtx.String("tls-ca-certificate"); additionalTLSCA != "" {
if !verifyTLS {
log.Error("--tls-ca-certificate specified but --verify-tls is not, refusing to continue")
return errors.New("--tls-ca-certificate specified but --verify-tls is not, refusing to continue")
// Maps service tag (id) to list of measurements (ids). Should be passed in separately. For now assume single "default" service.
serviceMeasurements := map[proxy.ServiceTag][]multimeasurements.MeasurementsContainer{proxy.ServiceTag("default"): parsedMeasurements}
serviceValidators := make(map[proxy.ServiceTag][]atls.Validator)
for service, listOfMeasurements := range serviceMeasurements {
validators, err := proxy.CreateAttestationValidatorsFromMeasurements(log, listOfMeasurements)
if err != nil {
return err
}
serviceValidators[service] = validators
}

certData, err := os.ReadFile(additionalTLSCA)
proxyByService := make(map[proxy.ServiceTag]http.HandlerFunc)
for service, validators := range serviceValidators {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, validators)
if err != nil {
log.Error("could not read tls ca certificate data", "err", err)
log.Error("could not create atls config", "err", err)
return err
}

roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM(certData)
if !ok {
log.Error("invalid certificate received", "cert", string(certData))
return errors.New("invalid certificate")
if config.VerifyTLS {
tlsConfig.InsecureSkipVerify = false
tlsConfig.ServerName = ""
}

tlsConfig.RootCAs = roots
if additionalTLSCA := config.TLSCACertificate; additionalTLSCA != "" {
if !config.VerifyTLS {
log.Error("--tls-ca-certificate specified but --verify-tls is not, refusing to continue")
return errors.New("--tls-ca-certificate specified but --verify-tls is not, refusing to continue")
}

certData, err := os.ReadFile(additionalTLSCA)
if err != nil {
log.Error("could not read tls ca certificate data", "err", err)
return err
}

roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM(certData)
if !ok {
log.Error("invalid certificate received", "cert", string(certData))
return errors.New("invalid certificate")
}

tlsConfig.RootCAs = roots
}

var proxyHandler http.HandlerFunc
if config.TargetAddr == "from_header" {
rproxyFactory := func(targetURL *url.URL) *httputil.ReverseProxy {
rproxy := httputil.NewSingleHostReverseProxy(targetURL)
rproxy.Transport = &http.Transport{TLSClientConfig: tlsConfig}
return rproxy
}
proxyHandler = proxy.NewDynamicHostReverseProxyFromHeader(log, validators, rproxyFactory)
} else {
rproxy := proxy.NewSingleHostReverseProxyFromUrl(log, validators, config.TargetAddr)
rproxy.Transport = &http.Transport{TLSClientConfig: tlsConfig}
proxyHandler = proxy.NewProxy(log, rproxy.ServeHTTP, validators).ServeHTTP
}
proxyByService[service] = proxyHandler
}

proxyHandler := proxy.NewProxy(log, targetAddr, validators).WithTransport(&http.Transport{TLSClientConfig: tlsConfig})
var proxyHandler http.Handler
if len(proxyByService) == 1 {
for _, onlyProxy := range proxyByService {
proxyHandler = onlyProxy
}
} else {
proxyHandler = proxy.NewMultiServiceMiddleware(proxyByService)
}

log.With("listenAddr", listenAddr).Info("Starting proxy client")
err = http.ListenAndServe(listenAddr, proxyHandler)
log.With("listenAddr", config.ListenAddr).Info("Starting proxy client")
err = http.ListenAndServe(config.ListenAddr, proxyHandler)
if err != nil {
log.Error("stopping proxy", "server error", err)
return err
Expand Down
62 changes: 39 additions & 23 deletions cmd/proxy-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ var flags []cli.Flag = []cli.Flag{
Usage: "Path to private key file for the certificate. Only valid with --tls-certificate-path",
},
&cli.StringFlag{
Name: "client-attestation-type",
Usage: "Deprecated and not used. Client attestation types are set via the measurements file.",
Name: "client-attestation-type",
Usage: "Deprecated and not used. Client attestation types are set via the measurements file.",
},
&cli.StringFlag{
Name: "client-measurements",
Expand Down Expand Up @@ -101,19 +101,31 @@ func main() {
}
}

type ServerConfig struct {
ListenAddr string `json:"listen_addr"`
ListenAddrHealthcheck string `json:"listen_addr_healthcheck"`
TargetAddr string `json:"target_addr"`
ServerAttestationType string `json:"server_attestation_type"`
TLSCertificatePath string `json:"tls_certificate_path"`
TLSPrivateKeyPath string `json:"tls_private_key_path"`
ClientMeasurements string `json:"client_measurements"`
}

func runServer(cCtx *cli.Context) error {
listenAddr := cCtx.String("listen-addr")
targetAddr := cCtx.String("target-addr")
clientMeasurements := cCtx.String("client-measurements")
config := ServerConfig{
ListenAddr: cCtx.String("listen-addr"),
ListenAddrHealthcheck: cCtx.String("listen-addr-healthcheck"),
TargetAddr: cCtx.String("target-addr"),
ServerAttestationType: cCtx.String("server-attestation-type"),
TLSCertificatePath: cCtx.String("tls-certificate"),
TLSPrivateKeyPath: cCtx.String("tls-private-key"),
ClientMeasurements: cCtx.String("client-measurements"),
}

logJSON := cCtx.Bool("log-json")
logDebug := cCtx.Bool("log-debug")
tdx.SetLogDcapQuote(cCtx.Bool("log-dcap-quote"))

serverAttestationTypeFlag := cCtx.String("server-attestation-type")

certFile := cCtx.String("tls-certificate")
keyFile := cCtx.String("tls-private-key")

log = common.SetupLogger(&common.LoggingOpts{
Debug: logDebug,
JSON: logJSON,
Expand All @@ -125,22 +137,26 @@ func runServer(cCtx *cli.Context) error {
log.Warn("DEPRECATED: --client-attestation-type is deprecated and will be removed in a future version")
}

useRegularTLS := certFile != "" || keyFile != ""
if serverAttestationTypeFlag != "none" && useRegularTLS {
return runServerFromConfig(log, config)
}

func runServerFromConfig(log *slog.Logger, config ServerConfig) error {
useRegularTLS := config.TLSCertificatePath != "" || config.TLSPrivateKeyPath != ""
if config.ServerAttestationType != "none" && useRegularTLS {
return errors.New("invalid combination of --tls-certificate-path, --tls-private-key-path and --server-attestation-type flags passed (only 'none' is allowed)")
}

if useRegularTLS && (certFile == "" || keyFile == "") {
if useRegularTLS && (config.TLSCertificatePath == "" || config.TLSPrivateKeyPath == "") {
return errors.New("not all of --tls-certificate-path and --tls-private-key-path specified")
}

serverAttestationType, err := proxy.ParseAttestationType(serverAttestationTypeFlag)
serverAttestationType, err := proxy.ParseAttestationType(config.ServerAttestationType)
if err != nil {
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
log.With("attestation-type", config.ServerAttestationType).Error("invalid server-attestation-type passed, see --help")
return err
}

validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements)
validators, err := proxy.CreateAttestationValidatorsFromFile(log, config.ClientMeasurements)
if err != nil {
log.Error("could not create attestation validators from file", "err", err)
return err
Expand All @@ -152,15 +168,16 @@ func runServer(cCtx *cli.Context) error {
return err
}

proxyHandler := proxy.NewProxy(log, targetAddr, validators)
rproxy := proxy.NewSingleHostReverseProxyFromUrl(log, validators, config.TargetAddr)
proxyHandler := proxy.NewProxy(log, rproxy.ServeHTTP, validators)

confTLS, err := atls.CreateAttestationServerTLSConfig(issuer, validators)
if err != nil {
panic(err)
}

if useRegularTLS {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
cert, err := tls.LoadX509KeyPair(config.TLSCertificatePath, config.TLSPrivateKeyPath)
if err != nil {
log.Error("could not load tls key pair", "err", err)
return err
Expand All @@ -185,7 +202,7 @@ func runServer(cCtx *cli.Context) error {

// Create an HTTP server
server := &http.Server{
Addr: listenAddr,
Addr: config.ListenAddr,
Handler: proxyHandler,
TLSConfig: confTLS,
}
Expand All @@ -212,12 +229,11 @@ func runServer(cCtx *cli.Context) error {
}()

// Start the health check server
listenAddrHealthCheck := cCtx.String("listen-addr-healthcheck")
if listenAddrHealthCheck != "" {
go startHealthCheckServer(listenAddrHealthCheck)
if config.ListenAddrHealthcheck != "" {
go startHealthCheckServer(config.ListenAddrHealthcheck)
}

log.With("listenAddr", listenAddr).Info("Starting proxy server")
log.With("listenAddr", config.ListenAddr).Info("Starting proxy server")
err = server.Serve(tlsListener)
if err != nil {
log.Error("stopping proxy", "server error", err)
Expand Down
Loading
Loading