Skip to content

Commit

Permalink
kdf: share Z hash state #220
Browse files Browse the repository at this point in the history
  • Loading branch information
emmansun authored May 15, 2024
1 parent 57318ea commit c99ad27
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 35 deletions.
5 changes: 2 additions & 3 deletions cfca/pkcs12_sm2.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"math/big"

"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/padding"
"github.com/emmansun/gmsm/pkcs"
"github.com/emmansun/gmsm/sm2"
Expand Down Expand Up @@ -59,7 +58,7 @@ func ParseSM2(password, data []byte) (*sm2.PrivateKey, *smx509.Certificate, erro
if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) {
return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm)
}
ivkey := kdf.Kdf(sm3.New(), password, 32)
ivkey := sm3.Kdf(password, 32)
marshalledIV, err := asn1.Marshal(ivkey[:16])
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -91,7 +90,7 @@ func MarshalSM2(password []byte, key *sm2.PrivateKey, cert *smx509.Certificate)
if len(password) == 0 {
return nil, errors.New("cfca: invalid password")
}
ivkey := kdf.Kdf(sm3.New(), password, 32)
ivkey := sm3.Kdf(password, 32)
block, err := sm4.NewCipher(ivkey[16:])
if err != nil {
return nil, err
Expand Down
41 changes: 31 additions & 10 deletions kdf/kdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,48 @@
package kdf

import (
"encoding"
"encoding/binary"
"hash"
)

// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
// ANSI-X9.63-KDF
func Kdf(md hash.Hash, z []byte, len int) []byte {
limit := uint64(len+md.Size()-1) / uint64(md.Size())
func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
baseMD := newHash()
limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size())
if limit >= uint64(1<<32)-1 {
panic("kdf: key length too long")
}
var countBytes [4]byte
var ct uint32 = 1
var k []byte
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write(z)
md.Write(countBytes[:])
k = md.Sum(k)
ct++
md.Reset()

marshaler, ok := baseMD.(encoding.BinaryMarshaler)
if limit == 1 || len(z) < baseMD.BlockSize() || !ok {
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
baseMD.Write(z)
baseMD.Write(countBytes[:])
k = baseMD.Sum(k)
ct++
baseMD.Reset()
}
} else {
baseMD.Write(z)
zstate, _ := marshaler.MarshalBinary()
for i := 0; i < int(limit); i++ {
md := newHash()
err := md.(encoding.BinaryUnmarshaler).UnmarshalBinary(zstate)
if err != nil {
panic(err)
}
binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write(countBytes[:])
k = md.Sum(k)
ct++
}
}
return k[:len]

return k[:keyLen]
}
2 changes: 1 addition & 1 deletion kdf/kdf_64bit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ import (
// This case should be failed on 32bits system.
func TestKdfPanic(t *testing.T) {
shouldPanic(t, func() {
Kdf(sm3.New(), []byte("123456"), 1<<37)
Kdf(sm3.New, []byte("123456"), 1<<37)
})
}
11 changes: 6 additions & 5 deletions kdf/kdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestKdf(t *testing.T) {
for _, tt := range tests {
wantBytes, _ := hex.DecodeString(tt.want)
t.Run(tt.name, func(t *testing.T) {
if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
if got := Kdf(sm3.New, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
}
})
Expand All @@ -44,7 +44,7 @@ func TestKdfOldCase(t *testing.T) {

expected := "006e30dae231b071dfad8aa379e90264491603"

result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19)
result := Kdf(sm3.New, append(x2.Bytes(), y2.Bytes()...), 19)

resultStr := hex.EncodeToString(result)

Expand All @@ -71,16 +71,17 @@ func BenchmarkKdf(b *testing.B) {
{64, 32},
{64, 64},
{64, 128},
{440, 32},
{64, 256},
{64, 512},
{64, 1024},
}
sm3Hash := sm3.New()
z := make([]byte, 512)
for _, tt := range tests {
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Kdf(sm3Hash, z[:tt.zLen], tt.kLen)
Kdf(sm3.New, z[:tt.zLen], tt.kLen)
}
})
}
Expand Down
5 changes: 2 additions & 3 deletions sm2/sm2.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/emmansun/gmsm/internal/randutil"
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte"
Expand Down Expand Up @@ -251,7 +250,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt
return nil, err
}
C2Bytes := C2.Bytes()[1:]
c2 := kdf.Kdf(sm3.New(), C2Bytes, len(msg))
c2 := sm3.Kdf(C2Bytes, len(msg))
if subtle.ConstantTimeAllZero(c2) {
retryCount++
if retryCount > maxRetryLimit {
Expand Down Expand Up @@ -424,7 +423,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp
}
C2Bytes := C2.Bytes()[1:]
msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen)
msg := sm3.Kdf(C2Bytes, msgLen)
if subtle.ConstantTimeAllZero(c2) {
return nil, ErrDecryption
}
Expand Down
3 changes: 1 addition & 2 deletions sm2/sm2_keyexchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"math/big"

"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3"
)

Expand Down Expand Up @@ -185,7 +184,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
buffer = append(buffer, ke.z...)
buffer = append(buffer, ke.peerZ...)
}
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
return sm3.Kdf(buffer, ke.keyLength), nil
}

// avf is the associative value function.
Expand Down
5 changes: 2 additions & 3 deletions sm2/sm2_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"

"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte"
Expand Down Expand Up @@ -260,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())

//A5, calculate t=KDF(x2||y2, klen)
c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) {
retryCount++
if retryCount > maxRetryLimit {
Expand Down Expand Up @@ -408,7 +407,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
curve := priv.Curve
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) {
return nil, ErrDecryption
}
Expand Down
8 changes: 8 additions & 0 deletions sm2/sm2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,11 @@ func BenchmarkMoreThan32_P256(b *testing.B) {
func BenchmarkMoreThan32_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
}

func BenchmarkEncrypt512_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
}

func BenchmarkEncrypt1024_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption sencryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
}
23 changes: 23 additions & 0 deletions sm3/sm3.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,26 @@ func Sum(data []byte) [Size]byte {
d.Write(data)
return d.checkSum()
}

// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3.
func Kdf(z []byte, keyLen int) []byte {
limit := uint64(keyLen+Size-1) / uint64(Size)
if limit >= uint64(1<<32)-1 {
panic("sm3: key length too long")
}
var countBytes [4]byte
var ct uint32 = 1
var k []byte
baseMD := new(digest)
baseMD.Reset()
baseMD.Write(z)
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
md := *baseMD
md.Write(countBytes[:])
h := md.checkSum()
k = append(k, h[:]...)
ct++
}
return k[:keyLen]
}
71 changes: 71 additions & 0 deletions sm3/sm3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"fmt"
"hash"
"io"
"math/big"
"reflect"
"testing"

"golang.org/x/sys/cpu"
Expand Down Expand Up @@ -403,6 +405,75 @@ func BenchmarkHash8K_SH256(b *testing.B) {
benchmarkSize(benchSH256, b, 8192)
}

func TestKdf(t *testing.T) {
type args struct {
z []byte
len int
}
tests := []struct {
name string
args args
want string
}{
{"sm3 case 1", args{[]byte("emmansun"), 16}, "708993ef1388a0ae4245a19bb6c02554"},
{"sm3 case 2", args{[]byte("emmansun"), 32}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd4"},
{"sm3 case 3", args{[]byte("emmansun"), 48}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"},
{"sm3 case 4", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 48}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f"},
{"sm3 case 5", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 128}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f30277f3179baebd795e7853fa643fdf280d8d7b81a2ab7829f615e132ab376d32194cd315908d27090e1180ce442d9be99322523db5bfac40ac5acb03550f5c93e5b01b1d71f2630868909a6a1250edb"},
}
for _, tt := range tests {
wantBytes, _ := hex.DecodeString(tt.want)
t.Run(tt.name, func(t *testing.T) {
if got := Kdf(tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
}
})
}
}

func TestKdfOldCase(t *testing.T) {
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)

expected := "006e30dae231b071dfad8aa379e90264491603"

result := Kdf(append(x2.Bytes(), y2.Bytes()...), 19)

resultStr := hex.EncodeToString(result)

if expected != resultStr {
t.Fatalf("expected %s, real value %s", expected, resultStr)
}
}

func BenchmarkKdfWithSM3(b *testing.B) {
tests := []struct {
zLen int
kLen int
}{
{32, 32},
{32, 64},
{32, 128},
{64, 32},
{64, 64},
{64, 128},
{64, 256},
{64, 512},
{64, 1024},
{64, 1024*8},
}
z := make([]byte, 512)
for _, tt := range tests {
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Kdf(z[:tt.zLen], tt.kLen)
}
})
}
}

/*
func round1(a, b, c, d, e, f, g, h string, i int) {
fmt.Printf("//Round %d\n", i+1)
Expand Down
7 changes: 3 additions & 4 deletions sm9/sm9.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/randutil"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte"
Expand Down Expand Up @@ -317,7 +316,7 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte,
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)

key = kdf.Kdf(sm3.New(), buffer, kLen)
key = sm3.Kdf(buffer, kLen)
if !subtle.ConstantTimeAllZero(key) {
break
}
Expand Down Expand Up @@ -403,7 +402,7 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int)
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)

key := kdf.Kdf(sm3.New(), buffer, kLen)
key := sm3.Kdf(buffer, kLen)
if subtle.ConstantTimeAllZero(key) {
return nil, ErrDecryption
}
Expand Down Expand Up @@ -685,7 +684,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...)

return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
return sm3.Kdf(buffer, ke.keyLength), nil
}

func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {
Expand Down
7 changes: 3 additions & 4 deletions sm9/sm9_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte"
Expand Down Expand Up @@ -563,7 +562,7 @@ func TestWrapKeySM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)

key := kdf.Kdf(sm3.New(), buffer, 32)
key := sm3.Kdf(buffer, 32)

if hex.EncodeToString(key) != expectedKey {
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
Expand Down Expand Up @@ -629,7 +628,7 @@ func TestEncryptSM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)

key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32)
key := sm3.Kdf(buffer, len(plaintext)+32)

if hex.EncodeToString(key) != expectedKey {
t.Errorf("not expected key")
Expand Down Expand Up @@ -697,7 +696,7 @@ func TestEncryptSM9SampleBlockMode(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)

key := kdf.Kdf(sm3.New(), buffer, 16+32)
key := sm3.Kdf(buffer, 16+32)

if hex.EncodeToString(key) != expectedKey {
t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key)
Expand Down

0 comments on commit c99ad27

Please sign in to comment.