Skip to content

Commit 04abe8c

Browse files
authored
export GetMeasurementsFromTLS and make independent of proxy (#23)
1 parent 837588b commit 04abe8c

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

proxy/atls_config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ import (
1212

1313
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
1414
azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx"
15-
dcap_tdx "github.com/flashbots/cvm-reverse-proxy/tdx"
1615
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements"
1716
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant"
1817
"github.com/flashbots/cvm-reverse-proxy/internal/cloud/cloudprovider"
1918
"github.com/flashbots/cvm-reverse-proxy/internal/config"
19+
dcap_tdx "github.com/flashbots/cvm-reverse-proxy/tdx"
2020
)
2121

2222
type AttestationType string

proxy/proxy.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxy
22

33
import (
44
"crypto/tls"
5+
"crypto/x509"
56
"crypto/x509/pkix"
67
"encoding/asn1"
78
"encoding/hex"
@@ -108,12 +109,12 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
108109
p.log.With("duration", duration).Info("[proxy-request] proxying complete")
109110
}
110111

111-
func (p *Proxy) getMeasurementsFromTLS(conn *tls.ConnectionState) (atlsVariant variant.Variant, measurements map[uint32][]byte, err error) {
112+
func GetMeasurementsFromTLS(certs []*x509.Certificate, validatorOIDs []asn1.ObjectIdentifier) (atlsVariant variant.Variant, measurements map[uint32][]byte, err error) {
112113
// In verifyEmbeddedReport which is used to validate the extensions, only the first matching extension is validated! Refuse to accept multiple
113114
var ATLSExtension *pkix.Extension = nil
114-
for _, cert := range conn.PeerCertificates {
115+
for _, cert := range certs {
115116
for _, ext := range cert.Extensions {
116-
for _, validatorOID := range p.validatorOIDs {
117+
for _, validatorOID := range validatorOIDs {
117118
if ext.Id.Equal(validatorOID) {
118119
if ATLSExtension != nil {
119120
return nil, nil, errors.New("more than one ATLS extension provided, refusing to continue")
@@ -142,7 +143,8 @@ func (p *Proxy) getMeasurementsFromTLS(conn *tls.ConnectionState) (atlsVariant v
142143
}
143144

144145
func (p *Proxy) copyMeasurementsToHeader(conn *tls.ConnectionState, header *http.Header) (int, error) {
145-
atlsVariant, extractedMeasurements, err := p.getMeasurementsFromTLS(conn)
146+
certs := conn.PeerCertificates
147+
atlsVariant, extractedMeasurements, err := GetMeasurementsFromTLS(certs, p.validatorOIDs)
146148
if err != nil {
147149
return http.StatusTeapot, err
148150
} else if extractedMeasurements == nil {

0 commit comments

Comments
 (0)