@@ -29,7 +29,10 @@ import (
2929
3030 "github.com/flashbots/cvm-reverse-proxy/common"
3131 "github.com/flashbots/cvm-reverse-proxy/internal/atls"
32+ azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx"
33+ "github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements"
3234 "github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant"
35+ "github.com/flashbots/cvm-reverse-proxy/internal/config"
3336 "github.com/flashbots/cvm-reverse-proxy/proxy"
3437 "github.com/urfave/cli/v2" // imports as package "cli"
3538)
@@ -50,6 +53,11 @@ var flags []cli.Flag = []cli.Flag{
5053 Value : "" ,
5154 Usage : "Output file for the response payload" ,
5255 },
56+ & cli.StringFlag {
57+ Name : "attestation-type" , // TODO: Add support for other attestation types
58+ Value : string (proxy .AttestationAzureTDX ),
59+ Usage : "type of attestation to present (currently only azure-tdx)" ,
60+ },
5361 & cli.BoolFlag {
5462 Name : "log-debug" ,
5563 Value : false ,
@@ -75,6 +83,7 @@ func runClient(cCtx *cli.Context) (err error) {
7583 addr := cCtx .String ("addr" )
7684 outMeasurements := cCtx .String ("out-measurements" )
7785 outResponse := cCtx .String ("out-response" )
86+ attestationTypeStr := cCtx .String ("attestation-type" )
7887
7988 // Setup logging
8089 log := common .SetupLogger (& common.LoggingOpts {
@@ -88,7 +97,25 @@ func runClient(cCtx *cli.Context) (err error) {
8897 return errors .New ("address needs to start with https://" )
8998 }
9099
91- log .Info ("Getting verified measurements from " + addr + " ..." )
100+ // Create validators based on the attestation type
101+ var validators []atls.Validator
102+ attestationType , err := proxy .ParseAttestationType (attestationTypeStr )
103+ if err != nil {
104+ log .With ("attestation-type" , attestationType ).Error ("invalid attestation-type passed, see --help" )
105+ return err
106+ }
107+
108+ switch attestationType {
109+ case proxy .AttestationAzureTDX :
110+ // Prepare an azure-tdx validator without any required measurements
111+ attConfig := config .DefaultForAzureTDX ()
112+ attConfig .SetMeasurements (measurements.M {})
113+ validator := azure_tdx .NewValidator (attConfig , proxy.AttestationLogger {Log : log })
114+ validators = append (validators , validator )
115+ default :
116+ log .Error ("currently only azure-tdx attestation is supported" )
117+ return errors .New ("currently only azure-tdx attestation is supported" )
118+ }
92119
93120 // Prepare aTLS stuff
94121 issuer , err := proxy .CreateAttestationIssuer (log , proxy .AttestationAzureTDX )
@@ -97,24 +124,27 @@ func runClient(cCtx *cli.Context) (err error) {
97124 return err
98125 }
99126
100- tlsConfig , err := atls .CreateAttestationClientTLSConfig (issuer , []atls.Validator {})
127+ // Create the (a)TLS config
128+ tlsConfig , err := atls .CreateAttestationClientTLSConfig (issuer , validators )
101129 if err != nil {
102130 log .Error ("could not create atls config" , "err" , err )
103131 return err
104132 }
105133
106- tr := & http.Transport {
134+ // Prepare the client
135+ client := & http.Client {Transport : & http.Transport {
107136 TLSClientConfig : tlsConfig ,
108- }
109- client := & http.Client {Transport : tr }
137+ }}
138+
139+ // Execute the GET request
140+ log .Info ("Executing attested GET request to " + addr + " ..." )
110141 resp , err := client .Get (addr )
111142 if err != nil {
112143 return err
113144 }
114- certs := resp .TLS .PeerCertificates
115145
116146 // Extract the aTLS variant and measurements from the TLS connection
117- atlsVariant , extractedMeasurements , err := proxy .GetMeasurementsFromTLS (certs , []asn1.ObjectIdentifier {variant.AzureTDX {}.OID ()})
147+ atlsVariant , extractedMeasurements , err := proxy .GetMeasurementsFromTLS (resp . TLS . PeerCertificates , []asn1.ObjectIdentifier {variant.AzureTDX {}.OID ()})
118148 if err != nil {
119149 log .Error ("Error in getMeasurementsFromTLS" , "err" , err )
120150 return err
@@ -130,11 +160,10 @@ func runClient(cCtx *cli.Context) (err error) {
130160 return errors .New ("could not marshal measurement extracted from tls extension" )
131161 }
132162
133- log .Info ("Variant: " + atlsVariant .String ())
134163 log .Info (fmt .Sprintf ("Measurements for %s with %d entries:" , atlsVariant .String (), len (measurementsInHeaderFormat )))
135164 fmt .Println (string (marshaledPcrs ))
136165 if outMeasurements != "" {
137- if err := os .WriteFile (outMeasurements , marshaledPcrs , 0644 ); err != nil {
166+ if err := os .WriteFile (outMeasurements , marshaledPcrs , 0o644 ); err != nil {
138167 return err
139168 }
140169 }
@@ -148,7 +177,7 @@ func runClient(cCtx *cli.Context) (err error) {
148177 log .Info (fmt .Sprintf ("Response body with %d bytes:" , len (msg )))
149178 fmt .Println (string (msg ))
150179 if outResponse != "" {
151- if err := os .WriteFile (outResponse , msg , 0644 ); err != nil {
180+ if err := os .WriteFile (outResponse , msg , 0o644 ); err != nil {
152181 return err
153182 }
154183 }
0 commit comments