Skip to content

Commit

Permalink
refactor+feat: Custom Client Handshake + Implement ALPS extension (re…
Browse files Browse the repository at this point in the history
…fraction-networking#142)

* refactor: split `CompressCertExtension` changes

- Split most of changes for `CompressCertExtension` made to `crypto/tls` files out and moved them to `u_` files.
- Edited some `crypto/tls` files to achieve better programmability for uTLS.
- Minor styling fix.

* feat: implement ALPS Extension draft

- Made necessary modifications to existing types to support ALPS.
- Ported `ApplicationSettingsExtension` implementation from `ulixee/utls` by @blakebyrnes with some adaptation.

Co-Authored-By: Blake Byrnes <115056+blakebyrnes@users.noreply.github.com>

* feat: utlsFakeCustomExtension in ALPS

- Introducing `utlsFakeCustomExtension` to enable implementation for custom extensions to be exchanged via ALPS.
- currently it doesn't do anything.

Co-Authored-By: Blake Byrnes <115056+blakebyrnes@users.noreply.github.com>

* fix: magic number in `StatusRequestV2Extension`

- Fixed magic number `17` in `StatusRequestV2Extension` with pre-defined enum `extensionStatusRequestV2`.

Co-authored-by: Blake Byrnes <115056+blakebyrnes@users.noreply.github.com>
  • Loading branch information
gaukas and blakebyrnes authored Nov 17, 2022
1 parent 1b3a9ad commit fb99df2
Show file tree
Hide file tree
Showing 13 changed files with 375 additions and 140 deletions.
12 changes: 11 additions & 1 deletion common.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ const (
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionStatusRequestV2 uint16 = 17
extensionSCT uint16 = 18
extensionDelegatedCredentials uint16 = 34
extensionSessionTicket uint16 = 35
Expand All @@ -100,7 +101,7 @@ const (
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019.
extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019.
extensionRenegotiationInfo uint16 = 0xff01
)

Expand Down Expand Up @@ -237,6 +238,10 @@ type ConnectionState struct {
// Deprecated: this value is always true.
NegotiatedProtocolIsMutual bool

// PeerApplicationSettings is the Application-Layer Protocol Settings (ALPS)
// provided by peer.
PeerApplicationSettings []byte // [uTLS]

// ServerName is the value of the Server Name Indication extension sent by
// the client. It's available both on the server and on the client side.
ServerName string
Expand Down Expand Up @@ -624,6 +629,10 @@ type Config struct {
// ConnectionState.NegotiatedProtocol will be empty.
NextProtos []string

// ApplicationSettings is a set of application settings (ALPS) to use
// with each application protocol (ALPN).
ApplicationSettings map[string][]byte // [uTLS]

// ServerName is used to verify the hostname on the returned
// certificates unless InsecureSkipVerify is given. It is also included
// in the client's handshake to support virtual hosting unless it is
Expand Down Expand Up @@ -799,6 +808,7 @@ func (c *Config) Clone() *Config {
VerifyConnection: c.VerifyConnection,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ApplicationSettings: c.ApplicationSettings,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
Expand Down
22 changes: 16 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ type Conn struct {
// clientProtocol is the negotiated ALPN protocol.
clientProtocol string

// [UTLS SECTION START]
utls utlsConnExtraFields // used for extensive things such as ALPS
// [UTLS SECTION END]

// input/output
in, out halfConn
rawInput bytes.Buffer // raw input, starting with a record header
Expand Down Expand Up @@ -1075,17 +1079,22 @@ func (c *Conn) readHandshake() (any, error) {
}
case typeFinished:
m = new(finishedMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
// [uTLS] Commented typeEncryptedExtensions to force
// utlsHandshakeMessageType to handle it
// case typeEncryptedExtensions:
// m = new(encryptedExtensionsMsg)
case typeEndOfEarlyData:
m = new(endOfEarlyDataMsg)
case typeKeyUpdate:
m = new(keyUpdateMsg)
// [UTLS SECTION BEGINS]
case typeCompressedCertificate:
m = new(compressedCertificateMsg)
// [UTLS SECTION ENDS]
default:
// [UTLS SECTION BEGINS]
var err error
m, err = c.utlsHandshakeMessageType(data[0]) // see u_conn.go
if err == nil {
break
}
// [UTLS SECTION ENDS]
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}

Expand Down Expand Up @@ -1514,6 +1523,7 @@ func (c *Conn) connectionStateLocked() ConnectionState {
} else {
state.ekm = c.ekm
}

return state
}

Expand Down
121 changes: 26 additions & 95 deletions handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@ package tls

import (
"bytes"
"compress/zlib"
"context"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
"sync/atomic"
"time"

"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
)

type clientHandshakeStateTLS13 struct {
Expand Down Expand Up @@ -103,6 +98,11 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if err := hs.readServerFinished(); err != nil {
return err
}
// [UTLS SECTION START]
if err := hs.serverFinishedReceived(); err != nil {
return err
}
// [UTLS SECTION END]
if err := hs.sendClientCertificate(); err != nil {
return err
}
Expand Down Expand Up @@ -477,6 +477,15 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
}
c.clientProtocol = encryptedExtensions.alpnProtocol

// [UTLS SECTION STARTS]
if hs.uconn != nil {
err = hs.utlsReadServerParameters(encryptedExtensions)
if err != nil {
c.sendAlert(alertUnsupportedExtension)
return err
}
}
// [UTLS SECTION ENDS]
return nil
}

Expand Down Expand Up @@ -516,19 +525,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
}

// [UTLS SECTION BEGINS]
receivedCompressedCert := false
// Check to see if we advertised any compression algorithms
if hs.uconn != nil && len(hs.uconn.certCompressionAlgs) > 0 {
// Check to see if the message is a compressed certificate message, otherwise move on.
compressedCertMsg, ok := msg.(*compressedCertificateMsg)
if ok {
receivedCompressedCert = true
hs.transcript.Write(compressedCertMsg.marshal())

msg, err = hs.decompressCert(*compressedCertMsg)
if err != nil {
return fmt.Errorf("tls: failed to decompress certificate message: %w", err)
}
var skipWritingCertToTranscript bool = false
if hs.uconn != nil {
processedMsg, err := hs.utlsReadServerCertificate(msg)
if err != nil {
return err
}
if processedMsg != nil {
skipWritingCertToTranscript = true
msg = processedMsg // msg is now a processed-by-extension certificateMsg
}
}
// [UTLS SECTION ENDS]
Expand All @@ -544,7 +549,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
}
// [UTLS SECTION BEGINS]
// Previously, this was simply 'hs.transcript.Write(certMsg.marshal())' (without the if).
if !receivedCompressedCert {
if !skipWritingCertToTranscript {
hs.transcript.Write(certMsg.marshal())
}
// [UTLS SECTION ENDS]
Expand All @@ -570,15 +575,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
return errors.New("tls: certificate used with invalid signature algorithm -- not implemented")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
return errors.New("tls: certificate used with invalid signature algorithm -- obsolete")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
Expand Down Expand Up @@ -729,80 +734,6 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
return nil
}

// [UTLS SECTION BEGINS]
func (hs *clientHandshakeStateTLS13) decompressCert(m compressedCertificateMsg) (*certificateMsgTLS13, error) {
var (
decompressed io.Reader
compressed = bytes.NewReader(m.compressedCertificateMessage)
c = hs.c
)

// Check to see if the peer responded with an algorithm we advertised.
supportedAlg := false
for _, alg := range hs.uconn.certCompressionAlgs {
if m.algorithm == uint16(alg) {
supportedAlg = true
}
}
if !supportedAlg {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm)
}

switch CertCompressionAlgo(m.algorithm) {
case CertCompressionBrotli:
decompressed = brotli.NewReader(compressed)

case CertCompressionZlib:
rc, err := zlib.NewReader(compressed)
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("failed to open zlib reader: %w", err)
}
defer rc.Close()
decompressed = rc

case CertCompressionZstd:
rc, err := zstd.NewReader(compressed)
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("failed to open zstd reader: %w", err)
}
defer rc.Close()
decompressed = rc

default:
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm)
}

rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field
rawMsg[0] = typeCertificate
rawMsg[1] = uint8(m.uncompressedLength >> 16)
rawMsg[2] = uint8(m.uncompressedLength >> 8)
rawMsg[3] = uint8(m.uncompressedLength)

n, err := decompressed.Read(rawMsg[4:])
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, err
}
if n < len(rawMsg)-4 {
// If, after decompression, the specified length does not match the actual length, the party
// receiving the invalid message MUST abort the connection with the "bad_certificate" alert.
// https://datatracker.ietf.org/doc/html/rfc8879#section-4
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength)
}
certMsg := new(certificateMsgTLS13)
if !certMsg.unmarshal(rawMsg) {
return nil, c.sendAlert(alertUnexpectedMessage)
}
return certMsg, nil
}

// [UTLS SECTION ENDS]

func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
Expand Down
7 changes: 7 additions & 0 deletions handshake_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,8 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string

utls utlsEncryptedExtensionsMsgExtraFields // [uTLS]
}

func (m *encryptedExtensionsMsg) marshal() []byte {
Expand Down Expand Up @@ -927,6 +929,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
}
m.alpnProtocol = string(proto)
default:
// [UTLS SECTION START]
if !m.utlsUnmarshal(extension, extData) {
return false // return false when ERROR
}
// [UTLS SECTION END]
// Ignore unknown extensions.
continue
}
Expand Down
6 changes: 3 additions & 3 deletions handshake_messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ var tests = []any{
&newSessionTicketMsgTLS13{},
&certificateRequestMsgTLS13{},
&certificateMsgTLS13{},
&compressedCertificateMsg{}, // [UTLS]
&utlsCompressedCertificateMsg{}, // [UTLS]
}

func TestMarshalUnmarshal(t *testing.T) {
Expand Down Expand Up @@ -406,8 +406,8 @@ func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
}

// [UTLS]
func (*compressedCertificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &compressedCertificateMsg{}
func (*utlsCompressedCertificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &utlsCompressedCertificateMsg{}
m.algorithm = uint16(rand.Intn(2 << 15))
m.uncompressedLength = uint32(rand.Intn(2 << 23))
m.compressedCertificateMessage = randomBytes(rand.Intn(500)+1, rand)
Expand Down
2 changes: 1 addition & 1 deletion key_agreement.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
}

if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
return fmt.Errorf("tls: certificate used with invalid signature algorithm -- ClientHello not advertising %04x", uint16(signatureAlgorithm))
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/refraction-networking/utls/testenv"
"io"
"math"
"net"
Expand All @@ -22,6 +21,8 @@ import (
"strings"
"testing"
"time"

"github.com/refraction-networking/utls/testenv"
)

var rsaCertPEM = `-----BEGIN CERTIFICATE-----
Expand Down Expand Up @@ -827,6 +828,8 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
continue // these are unexported fields that are handled separately
case "ApplicationSettings":
f.Set(reflect.ValueOf(map[string][]byte{"a": {1}}))
default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
}
Expand Down
Loading

0 comments on commit fb99df2

Please sign in to comment.