Skip to content

Commit d4c4207

Browse files
author
fnerdman
committed
feat: fetches validator attestation type from measurments, removes flag
1 parent e36bf5b commit d4c4207

File tree

3 files changed

+39
-40
lines changed

3 files changed

+39
-40
lines changed

cmd/proxy-client/main.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ var flags []cli.Flag = []cli.Flag{
2525
Value: "https://localhost:80",
2626
Usage: "address to proxy requests to",
2727
},
28-
&cli.StringFlag{
29-
Name: "server-attestation-type",
30-
Value: string(proxy.AttestationAzureTDX),
31-
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
32-
},
3328
&cli.StringFlag{
3429
Name: "server-measurements",
3530
Usage: "optional path to JSON measurements enforced on the server",
@@ -96,9 +91,9 @@ func runClient(cCtx *cli.Context) error {
9691
Version: common.Version,
9792
})
9893

99-
if cCtx.String("server-attestation-type") != "none" && verifyTLS {
100-
log.Error("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
101-
return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
94+
if serverMeasurements != nil && verifyTLS {
95+
log.Error("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
96+
return errors.New("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
10297
}
10398

10499
// Auto-detect client attestation type if not specified
@@ -126,9 +121,9 @@ func runClient(cCtx *cli.Context) error {
126121
return err
127122
}
128123

129-
validators, err := proxy.CreateAttestationValidators(log, serverAttestationType, serverMeasurements)
124+
validators, err = proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements)
130125
if err != nil {
131-
log.Error("could not create attestation validators", "err", err)
126+
log.Error("could not create attestation validators from file", "err", err)
132127
return err
133128
}
134129

cmd/proxy-server/main.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,6 @@ var flags []cli.Flag = []cli.Flag{
5353
EnvVars: []string{"TLS_PRIVATE_KEY_PATH"},
5454
Usage: "Path to private key file for the certificate. Only valid with --tls-certificate-path",
5555
},
56-
&cli.StringFlag{
57-
Name: "client-attestation-type",
58-
EnvVars: []string{"CLIENT_ATTESTATION_TYPE"},
59-
Value: string(proxy.AttestationNone),
60-
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
61-
},
6256
&cli.StringFlag{
6357
Name: "client-measurements",
6458
EnvVars: []string{"CLIENT_MEASUREMENTS"},
@@ -145,15 +139,9 @@ func runServer(cCtx *cli.Context) error {
145139
}
146140
}
147141

148-
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
149-
if err != nil {
150-
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
151-
return err
152-
}
153-
154-
validators, err := proxy.CreateAttestationValidators(log, clientAttestationType, clientMeasurements)
142+
validators, err = proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements)
155143
if err != nil {
156-
log.Error("could not create attestation validators", "err", err)
144+
log.Error("could not create attestation validators from file", "err", err)
157145
return err
158146
}
159147

proxy/atls_config.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ func CreateAttestationIssuer(log *slog.Logger, attestationType AttestationType)
7575
}
7676
}
7777

78-
func CreateAttestationValidators(log *slog.Logger, attestationType AttestationType, jsonMeasurementsPath string) ([]atls.Validator, error) {
79-
if attestationType == AttestationNone {
78+
func CreateAttestationValidatorsFromFile(log *slog.Logger, jsonMeasurementsPath string) ([]atls.Validator, error) {
79+
if jsonMeasurementsPath == "" {
8080
return nil, nil
8181
}
8282

@@ -91,26 +91,42 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy
9191
return nil, err
9292
}
9393

94-
switch attestationType {
95-
case AttestationAzureTDX:
96-
validators := []atls.Validator{}
97-
for _, measurement := range parsedMeasurements {
94+
// Group validators by attestation type
95+
validatorsByType := make(map[AttestationType][]atls.Validator)
96+
97+
for _, measurement := range parsedMeasurements {
98+
attestationType, err := ParseAttestationType(measurement.AttestationType)
99+
if err != nil {
100+
return nil, fmt.Errorf("invalid attestation type %s in measurements file", measurement.AttestationType)
101+
}
102+
103+
switch attestationType {
104+
case AttestationAzureTDX:
98105
attConfig := config.DefaultForAzureTDX()
99106
attConfig.SetMeasurements(measurement.Measurements)
100-
validators = append(validators, azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}))
101-
}
102-
return []atls.Validator{NewMultiValidator(validators)}, nil
103-
case AttestationDCAPTDX:
104-
validators := []atls.Validator{}
105-
for _, measurement := range parsedMeasurements {
107+
validatorsByType[attestationType] = append(
108+
validatorsByType[attestationType],
109+
azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}),
110+
)
111+
case AttestationDCAPTDX:
106112
attConfig := &config.QEMUTDX{Measurements: measurements.DefaultsFor(cloudprovider.QEMU, variant.QEMUTDX{})}
107113
attConfig.SetMeasurements(measurement.Measurements)
108-
validators = append(validators, dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}))
114+
validatorsByType[attestationType] = append(
115+
validatorsByType[attestationType],
116+
dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}),
117+
)
118+
default:
119+
return nil, fmt.Errorf("unsupported attestation type %s in measurements file", measurement.AttestationType)
109120
}
110-
return []atls.Validator{NewMultiValidator(validators)}, nil
111-
default:
112-
return nil, errors.New("invalid attestation-type passed in")
113121
}
122+
123+
// Create a MultiValidator for each attestation type
124+
var validators []atls.Validator
125+
for _, typeValidators := range validatorsByType {
126+
validators = append(validators, NewMultiValidator(typeValidators))
127+
}
128+
129+
return validators, nil
114130
}
115131

116132
func ExtractMeasurementsFromExtension(ext *pkix.Extension, v variant.Variant) (map[uint32][]byte, error) {

0 commit comments

Comments
 (0)