Skip to content

Commit d97a297

Browse files
committed
use HTTP GET request, and save to file
1 parent 607ce3c commit d97a297

File tree

1 file changed

+58
-15
lines changed

1 file changed

+58
-15
lines changed

cmd/get-measurements/main.go

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
11
package main
22

33
//
4-
// CLI tool to get and print verified measurements from an aTLS server.
4+
// Make a HTTP GET request over a TEE-attested connection (to a server with aTLS support),
5+
// and print the verified measurements and the response payload.
56
//
67
// Currently only works for Azure TDX but should be easy to expand.
78
//
89
// Usage:
910
//
10-
// go run cmd/get-measurements/main.go instance_ip:port
11+
// go run cmd/get-measurements/main.go --addr=https://instance_ip:port
12+
//
13+
// Can also save the verified measurements and the response body to files:
14+
//
15+
// go run cmd/get-measurements/main.go --addr=https://instance_ip:port --out-measurements=measurements.json --out-response=response.txt
1116
//
1217

1318
import (
14-
"crypto/tls"
1519
"encoding/asn1"
1620
"encoding/hex"
1721
"encoding/json"
1822
"errors"
1923
"fmt"
24+
"io"
2025
"log"
26+
"net/http"
2127
"os"
28+
"strings"
2229

2330
"github.com/flashbots/cvm-reverse-proxy/common"
2431
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
@@ -28,6 +35,21 @@ import (
2835
)
2936

3037
var flags []cli.Flag = []cli.Flag{
38+
&cli.StringFlag{
39+
Name: "addr",
40+
Value: "https://localhost:7936",
41+
Usage: "TEE server address",
42+
},
43+
&cli.StringFlag{
44+
Name: "out-measurements",
45+
Value: "",
46+
Usage: "Output file for the measurements",
47+
},
48+
&cli.StringFlag{
49+
Name: "out-response",
50+
Value: "",
51+
Usage: "Output file for the response payload",
52+
},
3153
&cli.BoolFlag{
3254
Name: "log-debug",
3355
Value: false,
@@ -48,8 +70,11 @@ func main() {
4870
}
4971
}
5072

51-
func runClient(cCtx *cli.Context) error {
73+
func runClient(cCtx *cli.Context) (err error) {
5274
logDebug := cCtx.Bool("log-debug")
75+
addr := cCtx.String("addr")
76+
outMeasurements := cCtx.String("out-measurements")
77+
outResponse := cCtx.String("out-response")
5378

5479
// Setup logging
5580
log := common.SetupLogger(&common.LoggingOpts{
@@ -59,14 +84,11 @@ func runClient(cCtx *cli.Context) error {
5984
Version: common.Version,
6085
})
6186

62-
addr := cCtx.Args().Get(0)
63-
if addr == "" {
64-
log.Error("Please provide an address as cli argument")
65-
return errors.New("provide an address as argument")
87+
if !strings.HasPrefix(addr, "https://") {
88+
return errors.New("address needs to start with https://")
6689
}
6790

6891
log.Info("Getting verified measurements from " + addr + " ...")
69-
7092
// Prepare aTLS stuff
7193
serverAttestationType := proxy.AttestationAzureTDX
7294
issuer, err := proxy.CreateAttestationIssuer(log, serverAttestationType)
@@ -87,16 +109,18 @@ func runClient(cCtx *cli.Context) error {
87109
return err
88110
}
89111

90-
// Open connection to the TDX server and verify the aTLS attestation
91-
conn, err := tls.Dial("tcp", addr, tlsConfig)
112+
tr := &http.Transport{
113+
TLSClientConfig: tlsConfig,
114+
}
115+
client := &http.Client{Transport: tr}
116+
resp, err := client.Get(addr)
92117
if err != nil {
93-
log.Error("Error in Dial", "err", err)
94118
return err
95119
}
96-
defer conn.Close()
120+
certs := resp.TLS.PeerCertificates
97121

98122
// Extract the aTLS variant and measurements from the TLS connection
99-
certs := conn.ConnectionState().PeerCertificates
123+
// certs := conn.ConnectionState().PeerCertificates
100124
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(certs, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
101125
if err != nil {
102126
log.Error("Error in getMeasurementsFromTLS", "err", err)
@@ -114,8 +138,27 @@ func runClient(cCtx *cli.Context) error {
114138
}
115139

116140
log.Info("Variant: " + atlsVariant.String())
117-
// log.Info("Measurements", "measurements", string(marshaledPcrs))
141+
log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat)))
118142
fmt.Println(string(marshaledPcrs))
143+
if outMeasurements != "" {
144+
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0644); err != nil {
145+
return err
146+
}
147+
}
148+
149+
// Print the response body
150+
msg, err := io.ReadAll(resp.Body)
151+
if err != nil {
152+
return err
153+
}
154+
155+
log.Info(fmt.Sprintf("Response body with %d bytes:", len(msg)))
156+
fmt.Println(string(msg))
157+
if outResponse != "" {
158+
if err := os.WriteFile(outResponse, msg, 0644); err != nil {
159+
return err
160+
}
161+
}
119162

120163
return nil
121164
}

0 commit comments

Comments
 (0)