diff --git a/crypto/paillier/paillier.go b/crypto/paillier/paillier.go index e2461c19..bf7e8399 100644 --- a/crypto/paillier/paillier.go +++ b/crypto/paillier/paillier.go @@ -52,7 +52,8 @@ type ( ) var ( - ErrMessageTooLong = fmt.Errorf("the message is too large or < 0") + ErrMessageTooLong = fmt.Errorf("the message is too large or < 0") + ErrMessageMalFormed = fmt.Errorf("the message is mal-formed") zero = big.NewInt(0) one = big.NewInt(1) @@ -173,6 +174,10 @@ func (privateKey *PrivateKey) Decrypt(c *big.Int) (m *big.Int, err error) { if c.Cmp(zero) == -1 || c.Cmp(N2) != -1 { // c < 0 || c >= N2 ? return nil, ErrMessageTooLong } + cg := new(big.Int).GCD(nil, nil, c, N2) + if cg.Cmp(one) == 1 { + return nil, ErrMessageMalFormed + } // 1. L(u) = (c^LambdaN-1 mod N2) / N Lc := L(new(big.Int).Exp(c, privateKey.LambdaN, N2), privateKey.N) // 2. L(u) = (Gamma^LambdaN-1 mod N2) / N diff --git a/crypto/paillier/paillier_test.go b/crypto/paillier/paillier_test.go index 053966ca..3914d9bf 100644 --- a/crypto/paillier/paillier_test.go +++ b/crypto/paillier/paillier_test.go @@ -64,6 +64,10 @@ func TestEncryptDecrypt(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 0, exp.Cmp(ret), "wrong decryption ", ret, " is not ", exp) + + cypher = new(big.Int).Set(privateKey.N) + _, err = privateKey.Decrypt(cypher) + assert.Error(t, err) } func TestHomoMul(t *testing.T) {