diff --git a/ecdh/sm2ec.go b/ecdh/sm2ec.go index b6cc3ed8..4c124a0a 100644 --- a/ecdh/sm2ec.go +++ b/ecdh/sm2ec.go @@ -51,7 +51,7 @@ func (c *sm2Curve) NewPrivateKey(key []byte) (*PrivateKey, error) { if len(key) != len(c.scalarOrderMinus1) { return nil, errors.New("ecdh: invalid private key size") } - if subtle.ConstantTimeAllZero(key) || !isLess(key, c.scalarOrderMinus1) { + if subtle.ConstantTimeAllZero(key) == 1 || !isLess(key, c.scalarOrderMinus1) { return nil, errInvalidPrivateKey } return &PrivateKey{ diff --git a/internal/subtle/constant_time.go b/internal/subtle/constant_time.go index f70ef384..ac5cb22f 100644 --- a/internal/subtle/constant_time.go +++ b/internal/subtle/constant_time.go @@ -1,9 +1,9 @@ package subtle -func ConstantTimeAllZero(bytes []byte) bool { +func ConstantTimeAllZero(bytes []byte) int { var b uint8 for _, v := range bytes { b |= v } - return b == 0 + return int((uint32(b) - 1) >> 31) } diff --git a/internal/subtle/constant_time_test.go b/internal/subtle/constant_time_test.go index 81346082..e3b4d59f 100644 --- a/internal/subtle/constant_time_test.go +++ b/internal/subtle/constant_time_test.go @@ -1,6 +1,9 @@ package subtle -import "testing" +import ( + "fmt" + "testing" +) func TestConstantTimeAllZero(t *testing.T) { type args struct { @@ -9,10 +12,10 @@ func TestConstantTimeAllZero(t *testing.T) { tests := []struct { name string args args - want bool + want int }{ - {"all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, true}, - {"not all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, false}, + {"all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, 1}, + {"not all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -22,3 +25,17 @@ func TestConstantTimeAllZero(t *testing.T) { }) } } + +func BenchmarkConstantTimeAllZero(b *testing.B) { + data := make([]byte, 1<<15) + sizes := []int64{1 << 3, 1 << 4, 1 << 5, 1 << 7, 1 << 11, 1 << 13, 1 << 15} + for _, size := range sizes { + b.Run(fmt.Sprintf("%dBytes", size), func(b *testing.B) { + s0 := data[:size] + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + ConstantTimeAllZero(s0) + } + }) + } +} diff --git a/sm2/sm2.go b/sm2/sm2.go index 15b5a473..526373e4 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -251,7 +251,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt } C2Bytes := C2.Bytes()[1:] c2 := sm3.Kdf(C2Bytes, len(msg)) - if subtle.ConstantTimeAllZero(c2) { + if subtle.ConstantTimeAllZero(c2) == 1 { retryCount++ if retryCount > maxRetryLimit { return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount) @@ -424,7 +424,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp C2Bytes := C2.Bytes()[1:] msgLen := len(c2) msg := sm3.Kdf(C2Bytes, msgLen) - if subtle.ConstantTimeAllZero(c2) { + if subtle.ConstantTimeAllZero(c2) == 1 { return nil, ErrDecryption } diff --git a/sm2/sm2_legacy.go b/sm2/sm2_legacy.go index 2cd8a98e..bd1a0bcc 100644 --- a/sm2/sm2_legacy.go +++ b/sm2/sm2_legacy.go @@ -260,7 +260,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc //A5, calculate t=KDF(x2||y2, klen) c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) - if subtle.ConstantTimeAllZero(c2) { + if subtle.ConstantTimeAllZero(c2) == 1 { retryCount++ if retryCount > maxRetryLimit { return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount) @@ -408,7 +408,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) msgLen := len(c2) msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) - if subtle.ConstantTimeAllZero(c2) { + if subtle.ConstantTimeAllZero(c2) == 1 { return nil, ErrDecryption } diff --git a/sm9/sm9.go b/sm9/sm9.go index bd0aaa1e..02c842f4 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -317,7 +317,7 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, buffer = append(buffer, uid...) key = sm3.Kdf(buffer, kLen) - if !subtle.ConstantTimeAllZero(key) { + if subtle.ConstantTimeAllZero(key) == 0 { break } } @@ -403,7 +403,7 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int) buffer = append(buffer, uid...) key := sm3.Kdf(buffer, kLen) - if subtle.ConstantTimeAllZero(key) { + if subtle.ConstantTimeAllZero(key) == 1 { return nil, ErrDecryption } return key, nil