11package main
22
33//
4- // CLI tool to get and print verified measurements from an aTLS server.
4+ // Make a HTTP GET request over a TEE-attested connection (to a server with aTLS support),
5+ // and print the verified measurements and the response payload.
56//
67// Currently only works for Azure TDX but should be easy to expand.
78//
89// Usage:
910//
10- // go run cmd/get-measurements/main.go instance_ip:port
11+ // go run cmd/get-measurements/main.go --addr=https://instance_ip:port
12+ //
13+ // Can also save the verified measurements and the response body to files:
14+ //
15+ // go run cmd/get-measurements/main.go --addr=https://instance_ip:port --out-measurements=measurements.json --out-response=response.txt
1116//
1217
1318import (
14- "crypto/tls"
1519 "encoding/asn1"
1620 "encoding/hex"
1721 "encoding/json"
1822 "errors"
1923 "fmt"
24+ "io"
2025 "log"
26+ "net/http"
2127 "os"
28+ "strings"
2229
2330 "github.com/flashbots/cvm-reverse-proxy/common"
2431 "github.com/flashbots/cvm-reverse-proxy/internal/atls"
@@ -28,6 +35,21 @@ import (
2835)
2936
3037var flags []cli.Flag = []cli.Flag {
38+ & cli.StringFlag {
39+ Name : "addr" ,
40+ Value : "https://localhost:7936" ,
41+ Usage : "TEE server address" ,
42+ },
43+ & cli.StringFlag {
44+ Name : "out-measurements" ,
45+ Value : "" ,
46+ Usage : "Output file for the measurements" ,
47+ },
48+ & cli.StringFlag {
49+ Name : "out-response" ,
50+ Value : "" ,
51+ Usage : "Output file for the response payload" ,
52+ },
3153 & cli.BoolFlag {
3254 Name : "log-debug" ,
3355 Value : false ,
@@ -48,8 +70,11 @@ func main() {
4870 }
4971}
5072
51- func runClient (cCtx * cli.Context ) error {
73+ func runClient (cCtx * cli.Context ) ( err error ) {
5274 logDebug := cCtx .Bool ("log-debug" )
75+ addr := cCtx .String ("addr" )
76+ outMeasurements := cCtx .String ("out-measurements" )
77+ outResponse := cCtx .String ("out-response" )
5378
5479 // Setup logging
5580 log := common .SetupLogger (& common.LoggingOpts {
@@ -59,14 +84,11 @@ func runClient(cCtx *cli.Context) error {
5984 Version : common .Version ,
6085 })
6186
62- addr := cCtx .Args ().Get (0 )
63- if addr == "" {
64- log .Error ("Please provide an address as cli argument" )
65- return errors .New ("provide an address as argument" )
87+ if ! strings .HasPrefix (addr , "https://" ) {
88+ return errors .New ("address needs to start with https://" )
6689 }
6790
6891 log .Info ("Getting verified measurements from " + addr + " ..." )
69-
7092 // Prepare aTLS stuff
7193 serverAttestationType := proxy .AttestationAzureTDX
7294 issuer , err := proxy .CreateAttestationIssuer (log , serverAttestationType )
@@ -87,16 +109,18 @@ func runClient(cCtx *cli.Context) error {
87109 return err
88110 }
89111
90- // Open connection to the TDX server and verify the aTLS attestation
91- conn , err := tls .Dial ("tcp" , addr , tlsConfig )
112+ tr := & http.Transport {
113+ TLSClientConfig : tlsConfig ,
114+ }
115+ client := & http.Client {Transport : tr }
116+ resp , err := client .Get (addr )
92117 if err != nil {
93- log .Error ("Error in Dial" , "err" , err )
94118 return err
95119 }
96- defer conn . Close ()
120+ certs := resp . TLS . PeerCertificates
97121
98122 // Extract the aTLS variant and measurements from the TLS connection
99- certs := conn .ConnectionState ().PeerCertificates
123+ // certs := conn.ConnectionState().PeerCertificates
100124 atlsVariant , extractedMeasurements , err := proxy .GetMeasurementsFromTLS (certs , []asn1.ObjectIdentifier {variant.AzureTDX {}.OID ()})
101125 if err != nil {
102126 log .Error ("Error in getMeasurementsFromTLS" , "err" , err )
@@ -114,8 +138,27 @@ func runClient(cCtx *cli.Context) error {
114138 }
115139
116140 log .Info ("Variant: " + atlsVariant .String ())
117- // log.Info("Measurements", "measurements", string(marshaledPcrs ))
141+ log .Info (fmt . Sprintf ( "Measurements for %s with %d entries: " , atlsVariant . String (), len ( measurementsInHeaderFormat ) ))
118142 fmt .Println (string (marshaledPcrs ))
143+ if outMeasurements != "" {
144+ if err := os .WriteFile (outMeasurements , marshaledPcrs , 0644 ); err != nil {
145+ return err
146+ }
147+ }
148+
149+ // Print the response body
150+ msg , err := io .ReadAll (resp .Body )
151+ if err != nil {
152+ return err
153+ }
154+
155+ log .Info (fmt .Sprintf ("Response body with %d bytes:" , len (msg )))
156+ fmt .Println (string (msg ))
157+ if outResponse != "" {
158+ if err := os .WriteFile (outResponse , msg , 0644 ); err != nil {
159+ return err
160+ }
161+ }
119162
120163 return nil
121164}
0 commit comments