Skip to content

Commit 1c36ebb

Browse files
committed
some fixes, and adding validator back in
1 parent e8bf254 commit 1c36ebb

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
/measurements.json
44
/build/
55
/quotes/
6+
/builder-cert.pem

cmd/attested-get/main.go

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ import (
2929

3030
"github.com/flashbots/cvm-reverse-proxy/common"
3131
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
32+
azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx"
33+
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements"
3234
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant"
35+
"github.com/flashbots/cvm-reverse-proxy/internal/config"
3336
"github.com/flashbots/cvm-reverse-proxy/proxy"
3437
"github.com/urfave/cli/v2" // imports as package "cli"
3538
)
@@ -50,6 +53,11 @@ var flags []cli.Flag = []cli.Flag{
5053
Value: "",
5154
Usage: "Output file for the response payload",
5255
},
56+
&cli.StringFlag{
57+
Name: "attestation-type", // TODO: Add support for other attestation types
58+
Value: string(proxy.AttestationAzureTDX),
59+
Usage: "type of attestation to present (currently only azure-tdx)",
60+
},
5361
&cli.BoolFlag{
5462
Name: "log-debug",
5563
Value: false,
@@ -75,6 +83,7 @@ func runClient(cCtx *cli.Context) (err error) {
7583
addr := cCtx.String("addr")
7684
outMeasurements := cCtx.String("out-measurements")
7785
outResponse := cCtx.String("out-response")
86+
attestationTypeStr := cCtx.String("attestation-type")
7887

7988
// Setup logging
8089
log := common.SetupLogger(&common.LoggingOpts{
@@ -88,7 +97,25 @@ func runClient(cCtx *cli.Context) (err error) {
8897
return errors.New("address needs to start with https://")
8998
}
9099

91-
log.Info("Getting verified measurements from " + addr + " ...")
100+
// Create validators based on the attestation type
101+
var validators []atls.Validator
102+
attestationType, err := proxy.ParseAttestationType(attestationTypeStr)
103+
if err != nil {
104+
log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help")
105+
return err
106+
}
107+
108+
switch attestationType {
109+
case proxy.AttestationAzureTDX:
110+
// Prepare an azure-tdx validator without any required measurements
111+
attConfig := config.DefaultForAzureTDX()
112+
attConfig.SetMeasurements(measurements.M{})
113+
validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log})
114+
validators = append(validators, validator)
115+
default:
116+
log.Error("currently only azure-tdx attestation is supported")
117+
return errors.New("currently only azure-tdx attestation is supported")
118+
}
92119

93120
// Prepare aTLS stuff
94121
issuer, err := proxy.CreateAttestationIssuer(log, proxy.AttestationAzureTDX)
@@ -97,24 +124,27 @@ func runClient(cCtx *cli.Context) (err error) {
97124
return err
98125
}
99126

100-
tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, []atls.Validator{})
127+
// Create the (a)TLS config
128+
tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, validators)
101129
if err != nil {
102130
log.Error("could not create atls config", "err", err)
103131
return err
104132
}
105133

106-
tr := &http.Transport{
134+
// Prepare the client
135+
client := &http.Client{Transport: &http.Transport{
107136
TLSClientConfig: tlsConfig,
108-
}
109-
client := &http.Client{Transport: tr}
137+
}}
138+
139+
// Execute the GET request
140+
log.Info("Executing attested GET request to " + addr + " ...")
110141
resp, err := client.Get(addr)
111142
if err != nil {
112143
return err
113144
}
114-
certs := resp.TLS.PeerCertificates
115145

116146
// Extract the aTLS variant and measurements from the TLS connection
117-
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(certs, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
147+
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(resp.TLS.PeerCertificates, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
118148
if err != nil {
119149
log.Error("Error in getMeasurementsFromTLS", "err", err)
120150
return err
@@ -130,11 +160,10 @@ func runClient(cCtx *cli.Context) (err error) {
130160
return errors.New("could not marshal measurement extracted from tls extension")
131161
}
132162

133-
log.Info("Variant: " + atlsVariant.String())
134163
log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat)))
135164
fmt.Println(string(marshaledPcrs))
136165
if outMeasurements != "" {
137-
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0644); err != nil {
166+
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0o644); err != nil {
138167
return err
139168
}
140169
}
@@ -148,7 +177,7 @@ func runClient(cCtx *cli.Context) (err error) {
148177
log.Info(fmt.Sprintf("Response body with %d bytes:", len(msg)))
149178
fmt.Println(string(msg))
150179
if outResponse != "" {
151-
if err := os.WriteFile(outResponse, msg, 0644); err != nil {
180+
if err := os.WriteFile(outResponse, msg, 0o644); err != nil {
152181
return err
153182
}
154183
}

0 commit comments

Comments
 (0)