Skip to content

Commit 816dcde

Browse files
authored
CLI tool for attested HTTP GET requests (#25)
* CLI tool to print verified measurements from an aTLS server * use HTTP GET request, and save to file * cleanup * rename to attested-get * some fixes, and adding validator back in
1 parent da9d29a commit 816dcde

File tree

2 files changed

+187
-0
lines changed

2 files changed

+187
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
/measurements.json
44
/build/
55
/quotes/
6+
/builder-cert.pem

cmd/attested-get/main.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package main
2+
3+
//
4+
// Make a HTTP GET request over a TEE-attested connection (to a server with aTLS support),
5+
// and print the verified measurements and the response payload.
6+
//
7+
// Currently only works for Azure TDX but is straight-forward to expand.
8+
//
9+
// Usage:
10+
//
11+
// go run cmd/attested-get/main.go --addr=https://instance_ip:port
12+
//
13+
// Can also save the verified measurements and the response body to files:
14+
//
15+
// go run cmd/attested-get/main.go --addr=https://instance_ip:port --out-measurements=measurements.json --out-response=response.txt
16+
//
17+
18+
import (
19+
"encoding/asn1"
20+
"encoding/hex"
21+
"encoding/json"
22+
"errors"
23+
"fmt"
24+
"io"
25+
"log"
26+
"net/http"
27+
"os"
28+
"strings"
29+
30+
"github.com/flashbots/cvm-reverse-proxy/common"
31+
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
32+
azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx"
33+
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements"
34+
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant"
35+
"github.com/flashbots/cvm-reverse-proxy/internal/config"
36+
"github.com/flashbots/cvm-reverse-proxy/proxy"
37+
"github.com/urfave/cli/v2" // imports as package "cli"
38+
)
39+
40+
var flags []cli.Flag = []cli.Flag{
41+
&cli.StringFlag{
42+
Name: "addr",
43+
Value: "https://localhost:7936",
44+
Usage: "TEE server address",
45+
},
46+
&cli.StringFlag{
47+
Name: "out-measurements",
48+
Value: "",
49+
Usage: "Output file for the measurements",
50+
},
51+
&cli.StringFlag{
52+
Name: "out-response",
53+
Value: "",
54+
Usage: "Output file for the response payload",
55+
},
56+
&cli.StringFlag{
57+
Name: "attestation-type", // TODO: Add support for other attestation types
58+
Value: string(proxy.AttestationAzureTDX),
59+
Usage: "type of attestation to present (currently only azure-tdx)",
60+
},
61+
&cli.BoolFlag{
62+
Name: "log-debug",
63+
Value: false,
64+
Usage: "log debug messages",
65+
},
66+
}
67+
68+
func main() {
69+
app := &cli.App{
70+
Name: "attested-get",
71+
Usage: "Get verified measurements",
72+
Flags: flags,
73+
Action: runClient,
74+
}
75+
76+
if err := app.Run(os.Args); err != nil {
77+
log.Fatal(err)
78+
}
79+
}
80+
81+
func runClient(cCtx *cli.Context) (err error) {
82+
logDebug := cCtx.Bool("log-debug")
83+
addr := cCtx.String("addr")
84+
outMeasurements := cCtx.String("out-measurements")
85+
outResponse := cCtx.String("out-response")
86+
attestationTypeStr := cCtx.String("attestation-type")
87+
88+
// Setup logging
89+
log := common.SetupLogger(&common.LoggingOpts{
90+
Debug: logDebug,
91+
JSON: false,
92+
Service: "attested-get",
93+
Version: common.Version,
94+
})
95+
96+
if !strings.HasPrefix(addr, "https://") {
97+
return errors.New("address needs to start with https://")
98+
}
99+
100+
// Create validators based on the attestation type
101+
attestationType, err := proxy.ParseAttestationType(attestationTypeStr)
102+
if err != nil {
103+
log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help")
104+
return err
105+
}
106+
107+
var validators []atls.Validator
108+
switch attestationType {
109+
case proxy.AttestationAzureTDX:
110+
// Prepare an azure-tdx validator without any required measurements
111+
attConfig := config.DefaultForAzureTDX()
112+
attConfig.SetMeasurements(measurements.M{})
113+
validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log})
114+
validators = append(validators, validator)
115+
default:
116+
log.Error("currently only azure-tdx attestation is supported")
117+
return errors.New("currently only azure-tdx attestation is supported")
118+
}
119+
120+
// Prepare aTLS stuff
121+
issuer, err := proxy.CreateAttestationIssuer(log, proxy.AttestationAzureTDX)
122+
if err != nil {
123+
log.Error("could not create attestation issuer", "err", err)
124+
return err
125+
}
126+
127+
// Create the (a)TLS config
128+
tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, validators)
129+
if err != nil {
130+
log.Error("could not create atls config", "err", err)
131+
return err
132+
}
133+
134+
// Prepare the client
135+
client := &http.Client{Transport: &http.Transport{
136+
TLSClientConfig: tlsConfig,
137+
}}
138+
139+
// Execute the GET request
140+
log.Info("Executing attested GET request to " + addr + " ...")
141+
resp, err := client.Get(addr)
142+
if err != nil {
143+
return err
144+
}
145+
146+
// Extract the aTLS variant and measurements from the TLS connection
147+
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(resp.TLS.PeerCertificates, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
148+
if err != nil {
149+
log.Error("Error in getMeasurementsFromTLS", "err", err)
150+
return err
151+
}
152+
153+
measurementsInHeaderFormat := make(map[uint32]string, len(extractedMeasurements))
154+
for pcr, value := range extractedMeasurements {
155+
measurementsInHeaderFormat[pcr] = hex.EncodeToString(value)
156+
}
157+
158+
marshaledPcrs, err := json.MarshalIndent(measurementsInHeaderFormat, "", " ")
159+
if err != nil {
160+
return errors.New("could not marshal measurement extracted from tls extension")
161+
}
162+
163+
log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat)))
164+
fmt.Println(string(marshaledPcrs))
165+
if outMeasurements != "" {
166+
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0o644); err != nil {
167+
return err
168+
}
169+
}
170+
171+
// Print the response body
172+
msg, err := io.ReadAll(resp.Body)
173+
if err != nil {
174+
return err
175+
}
176+
177+
log.Info(fmt.Sprintf("Response body with %d bytes:", len(msg)))
178+
fmt.Println(string(msg))
179+
if outResponse != "" {
180+
if err := os.WriteFile(outResponse, msg, 0o644); err != nil {
181+
return err
182+
}
183+
}
184+
185+
return nil
186+
}

0 commit comments

Comments
 (0)