Skip to content

Commit 6ea7791

Browse files
Add support for the TLS13-KDF algorithm
1 parent cc9339b commit 6ea7791

File tree

3 files changed

+227
-16
lines changed

3 files changed

+227
-16
lines changed

const.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ const ( //checkheader:ignore
4848
_DigestNameSHA2_256 cString = "SHA2-256\x00"
4949

5050
// KDF names
51-
_OSSL_KDF_NAME_HKDF cString = "HKDF\x00"
52-
_OSSL_KDF_NAME_PBKDF2 cString = "PBKDF2\x00"
53-
_OSSL_KDF_NAME_TLS1_PRF cString = "TLS1-PRF\x00"
54-
_OSSL_MAC_NAME_HMAC cString = "HMAC\x00"
51+
_OSSL_KDF_NAME_HKDF cString = "HKDF\x00"
52+
_OSSL_KDF_NAME_PBKDF2 cString = "PBKDF2\x00"
53+
_OSSL_KDF_NAME_TLS1_PRF cString = "TLS1-PRF\x00"
54+
_OSSL_KDF_NAME_TLS13_KDF cString = "TLS13-KDF\x00"
55+
_OSSL_MAC_NAME_HMAC cString = "HMAC\x00"
5556

5657
// KDF parameters
5758
_OSSL_KDF_PARAM_DIGEST cString = "digest\x00"
@@ -62,6 +63,11 @@ const ( //checkheader:ignore
6263
_OSSL_KDF_PARAM_SALT cString = "salt\x00"
6364
_OSSL_KDF_PARAM_MODE cString = "mode\x00"
6465

66+
// TLS3-KDF parameters
67+
_OSSL_KDF_PARAM_PREFIX cString = "prefix\x00"
68+
_OSSL_KDF_PARAM_LABEL cString = "label\x00"
69+
_OSSL_KDF_PARAM_DATA cString = "data\x00"
70+
6571
// PKEY parameters
6672
_OSSL_PKEY_PARAM_PUB_KEY cString = "pub\x00"
6773
_OSSL_PKEY_PARAM_PRIV_KEY cString = "priv\x00"

hkdf.go

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package openssl
44

55
import "C"
66
import (
7+
"bytes"
78
"errors"
89
"hash"
910
"io"
@@ -261,6 +262,80 @@ func (c *hkdf3) finalize() {
261262
}
262263
}
263264

265+
// fetchTLS13_KDF fetches the TLS13-KDF algorithm.
266+
// It is safe to call this function concurrently.
267+
// The returned EVP_KDF_PTR shouldn't be freed.
268+
var fetchTLS13_KDF = sync.OnceValues(func() (ossl.EVP_KDF_PTR, error) {
269+
checkMajorVersion(3)
270+
271+
kdf, err := ossl.EVP_KDF_fetch(nil, _OSSL_KDF_NAME_TLS13_KDF.ptr(), nil)
272+
if err != nil {
273+
return nil, err
274+
}
275+
return kdf, nil
276+
})
277+
278+
//
279+
// https://docs.openssl.org/3.4/man7/EVP_KDF-TLS13_KDF/
280+
// https://datatracker.ietf.org/doc/html/rfc8446#section-7.1
281+
//
282+
// This function parses the info parameter for TLS 1.3 KDF.
283+
// The info parameter is expected to be in the format:
284+
//
285+
// +--------------+-----------+-------------------+------------+------------------+
286+
// | total length | labelLen | label bytes... | contextLen | context bytes... |
287+
// | 2 bytes | 1 byte | labelLen bytes | 1 byte | contextLen bytes |
288+
// +--------------+-----------+-------------------+------------+------------------+
289+
//
290+
// The label bytes are expected to be begin with "tls13 ".
291+
//
292+
func ParseForTLS13(info []byte) (isTLS13 bool, label, context []byte) {
293+
isTLS13 = false
294+
label = nil
295+
context = nil
296+
297+
if len(info) <= 2 {
298+
return
299+
}
300+
301+
cursor := 2
302+
labelLen := int(info[cursor])
303+
304+
if labelLen < 7 {
305+
return
306+
}
307+
308+
cursor++
309+
if cursor+labelLen > len(info) {
310+
return
311+
}
312+
313+
314+
labelBytes := info[cursor : cursor+labelLen]
315+
cursor += labelLen
316+
317+
if !bytes.HasPrefix(labelBytes, []byte("tls13 ")) {
318+
return
319+
}
320+
321+
if cursor >= len(info) {
322+
return
323+
}
324+
325+
contextLen := int(info[cursor])
326+
cursor++
327+
if cursor+contextLen > len(info) {
328+
return
329+
}
330+
331+
// Success, set the out parameters
332+
label = labelBytes[len("tls13 "):]
333+
context = info[cursor : cursor+contextLen]
334+
isTLS13 = true
335+
336+
return
337+
}
338+
264339
// fetchHKDF3 fetches the HKDF algorithm.
265340
// It is safe to call this function concurrently.
266341
// The returned EVP_KDF_PTR shouldn't be freed.
@@ -278,7 +353,14 @@ var fetchHKDF3 = sync.OnceValues(func() (ossl.EVP_KDF_PTR, error) {
278353
func newHKDFCtx3(md ossl.EVP_MD_PTR, mode int32, secret, salt, pseudorandomKey, info []byte) (_ ossl.EVP_KDF_CTX_PTR, err error) {
279354
checkMajorVersion(3)
280355

281-
kdf, err := fetchHKDF3()
356+
useTLS13KDF, label, context := ParseForTLS13(info)
357+
358+
var kdf ossl.EVP_KDF_PTR
359+
if useTLS13KDF {
360+
kdf, err = fetchTLS13_KDF()
361+
} else {
362+
kdf, err = fetchHKDF3()
363+
}
282364
if err != nil {
283365
return nil, err
284366
}
@@ -298,17 +380,39 @@ func newHKDFCtx3(md ossl.EVP_MD_PTR, mode int32, secret, salt, pseudorandomKey,
298380
}
299381
bld.addUTF8String(_OSSL_KDF_PARAM_DIGEST, ossl.EVP_MD_get0_name(md), 0)
300382
bld.addInt32(_OSSL_KDF_PARAM_MODE, int32(mode))
301-
if len(secret) > 0 {
302-
bld.addOctetString(_OSSL_KDF_PARAM_KEY, secret)
303-
}
304-
if len(salt) > 0 {
305-
bld.addOctetString(_OSSL_KDF_PARAM_SALT, salt)
306-
}
307-
if len(pseudorandomKey) > 0 {
308-
bld.addOctetString(_OSSL_KDF_PARAM_KEY, pseudorandomKey)
309-
}
310-
if len(info) > 0 {
311-
bld.addOctetString(_OSSL_KDF_PARAM_INFO, info)
383+
384+
if useTLS13KDF {
385+
if (mode == ossl.EVP_KDF_HKDF_MODE_EXTRACT_ONLY) {
386+
//bld.addUTF8String(C.CString("mode"), C.CString("EXTRACT_ONLY"), 0)
387+
if len(salt) > 0 {
388+
bld.addOctetString(_OSSL_KDF_PARAM_SALT, salt)
389+
}
390+
if len(pseudorandomKey) > 0 {
391+
bld.addOctetString(_OSSL_KDF_PARAM_KEY, secret)
392+
}
393+
} else {
394+
//bld.addUTF8String(C.CString("mode"), C.CString("EXPAND_ONLY"), 0)
395+
bld.addOctetString(_OSSL_KDF_PARAM_PREFIX, []byte("tls13 "))
396+
bld.addOctetString(_OSSL_KDF_PARAM_LABEL, label)
397+
bld.addOctetString(_OSSL_KDF_PARAM_DATA, context)
398+
if len(pseudorandomKey) > 0 {
399+
bld.addOctetString(_OSSL_KDF_PARAM_KEY, pseudorandomKey)
400+
}
401+
}
402+
403+
} else {
404+
if len(secret) > 0 {
405+
bld.addOctetString(_OSSL_KDF_PARAM_KEY, secret)
406+
}
407+
if len(salt) > 0 {
408+
bld.addOctetString(_OSSL_KDF_PARAM_SALT, salt)
409+
}
410+
if len(pseudorandomKey) > 0 {
411+
bld.addOctetString(_OSSL_KDF_PARAM_KEY, pseudorandomKey)
412+
}
413+
if len(info) > 0 {
414+
bld.addOctetString(_OSSL_KDF_PARAM_INFO, info)
415+
}
312416
}
313417
params, err := bld.build()
314418
if err != nil {

hkdf_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openssl_test
22

33
import (
4+
"encoding/binary"
45
"bytes"
56
"hash"
67
"io"
@@ -448,3 +449,103 @@ func TestExpandHKDFOneShotLimit(t *testing.T) {
448449
t.Errorf("expected error for key expansion overflow")
449450
}
450451
}
452+
453+
func makeInfo(fullLabel string, context []byte) []byte {
454+
totalLen := 1 + len(fullLabel) + 1 + len(context)
455+
456+
info := make([]byte, 2+totalLen)
457+
binary.BigEndian.PutUint16(info[0:2], uint16(totalLen))
458+
info[2] = byte(len(fullLabel))
459+
copy(info[3:], fullLabel)
460+
info[3+len(fullLabel)] = byte(len(context))
461+
copy(info[4+len(fullLabel):], context)
462+
463+
return info
464+
}
465+
466+
func TestParseForTLS13_Valid(t *testing.T) {
467+
tests := []struct {
468+
name string
469+
labelPrefix string
470+
label string
471+
context []byte
472+
}{
473+
{"IV", "tls13 ", "iv", []byte{}},
474+
{"Traffic Secret", "tls13 ", "c hs traffic", []byte{0xaa, 0xbb}},
475+
{"Finished", "tls13 ", "finished", []byte{0x00}},
476+
}
477+
478+
for _, tt := range tests {
479+
t.Run(tt.name, func(t *testing.T) {
480+
info := makeInfo(tt.labelPrefix + tt.label, tt.context)
481+
isTLS13, label, context := openssl.ParseForTLS13(info)
482+
if !isTLS13 {
483+
t.Errorf("Expected TLS13 label, got isTLS13=false")
484+
}
485+
if !bytes.Equal(label, []byte(tt.label)) {
486+
t.Errorf("Label mismatch: got %q, want %q", label, tt.label)
487+
}
488+
if !bytes.Equal(context, tt.context) {
489+
t.Errorf("Context mismatch: got %x, want %x", context, tt.context)
490+
}
491+
})
492+
}
493+
}
494+
495+
func TestParseForTLS13_Invalid(t *testing.T) {
496+
tests := []struct {
497+
name string
498+
info []byte
499+
}{
500+
{"Missing tls13 prefix", makeInfo("foobar", []byte{0x01})},
501+
{"Too short", []byte{0x00}},
502+
{"Label length exceeds buffer", []byte{0xFF, 't'}},
503+
{"Incomplete prefix", []byte{0x06, 't', 'l', 's', '1', '3', ' ', 0x00 }}, // discovered by the fuzzer
504+
{"Correct prefix but missing context", []byte{0x08, 't', 'l', 's', '1', '3', ' ', 'i', 'v'}},
505+
{"Correct prefix but truncated context", []byte{0x08, 't', 'l', 's', '1', '3', ' ', 'i', 'v', 0x02}},
506+
{"Correct prefix but truncated context", []byte{0x06, 't', 'l', 's', '1', '3', ' ', 'i', 'v', 0x02}},
507+
}
508+
509+
for _, tt := range tests {
510+
t.Run(tt.name, func(t *testing.T) {
511+
isTLS13, label, context := openssl.ParseForTLS13(tt.info)
512+
if isTLS13 {
513+
t.Errorf("Expected isTLS13=false, got true")
514+
}
515+
if label != nil {
516+
t.Errorf("Expected label=nil, got %q", label)
517+
}
518+
if context != nil {
519+
t.Errorf("Expected context=nil, got %x", context)
520+
}
521+
})
522+
}
523+
}
524+
525+
// run the fuzzer with:
526+
// go test -fuzz=FuzzParseForTLS13
527+
func FuzzParseForTLS13(f *testing.F) {
528+
// Seed with known-good examples
529+
f.Add([]byte{0x08, 't', 'l', 's', '1', '3', ' ', 'i', 'v', 0x00}) // "tls13 iv" + empty context
530+
f.Add([]byte{0x0c, 't', 'l', 's', '1', '3', ' ', 'c', ' ', 'h', 's', ' ', 't', 'r', 'a', 'f', 'f', 'i', 'c', 0x02, 0xAA, 0xBB})
531+
532+
f.Fuzz(func(t *testing.T, data []byte) {
533+
defer func() {
534+
if r := recover(); r != nil {
535+
t.Errorf("panic with input: %x — %v", data, r)
536+
}
537+
}()
538+
539+
isTLS13, label, context := openssl.ParseForTLS13(data)
540+
541+
if isTLS13 {
542+
if len(label) == 0 {
543+
t.Errorf("isTLS13=true but label is empty (input: %x)", data)
544+
}
545+
// Context can be 0-length, but shouldn't go out of bounds
546+
if context == nil {
547+
t.Errorf("isTLS13=true but context is nil (input: %x)", data)
548+
}
549+
}
550+
})
551+
}

0 commit comments

Comments
 (0)