From afb5e27a15efe59e33c2825d40ef44995c13b8bc Mon Sep 17 00:00:00 2001 From: Helder Eijs Date: Wed, 20 Dec 2023 20:46:08 +0100 Subject: [PATCH] Fix side-channel leakage in RSA decryption --- lib/Crypto/Cipher/PKCS1_OAEP.py | 6 +- lib/Crypto/Cipher/PKCS1_v1_5.py | 7 +- lib/Crypto/Math/_IntegerBase.py | 20 +++ lib/Crypto/Math/_IntegerBase.pyi | 4 + lib/Crypto/Math/_IntegerCustom.py | 56 +++++++- lib/Crypto/Math/_IntegerGMP.py | 20 +++ lib/Crypto/Math/_IntegerNative.py | 12 ++ lib/Crypto/PublicKey/RSA.py | 10 +- lib/Crypto/SelfTest/Math/__init__.py | 2 + lib/Crypto/SelfTest/Math/test_Numbers.py | 28 ++++ lib/Crypto/SelfTest/Math/test_modmult.py | 120 ++++++++++++++++++ lib/Crypto/SelfTest/PublicKey/test_RSA.py | 4 +- .../SelfTest/PublicKey/test_import_RSA.py | 6 +- lib/Crypto/Signature/pkcs1_15.py | 9 +- lib/Crypto/Signature/pss.py | 11 +- src/bignum.c | 2 +- src/modexp.c | 68 +++++++++- 17 files changed, 350 insertions(+), 35 deletions(-) create mode 100644 lib/Crypto/SelfTest/Math/test_modmult.py diff --git a/lib/Crypto/Cipher/PKCS1_OAEP.py b/lib/Crypto/Cipher/PKCS1_OAEP.py index 57a982b85..8b0f89fc4 100644 --- a/lib/Crypto/Cipher/PKCS1_OAEP.py +++ b/lib/Crypto/Cipher/PKCS1_OAEP.py @@ -167,10 +167,8 @@ def decrypt(self, ciphertext): raise ValueError("Ciphertext with incorrect length.") # Step 2a (O2SIP) ct_int = bytes_to_long(ciphertext) - # Step 2b (RSADP) - m_int = self._key._decrypt(ct_int) - # Complete step 2c (I2OSP) - em = long_to_bytes(m_int, k) + # Step 2b (RSADP) and step 2c (I2OSP) + em = self._key._decrypt(ct_int) # Step 3a lHash = self._hashObj.new(self._label).digest() # Step 3b diff --git a/lib/Crypto/Cipher/PKCS1_v1_5.py b/lib/Crypto/Cipher/PKCS1_v1_5.py index d0d474a6f..db5e73153 100644 --- a/lib/Crypto/Cipher/PKCS1_v1_5.py +++ b/lib/Crypto/Cipher/PKCS1_v1_5.py @@ -176,11 +176,8 @@ def decrypt(self, ciphertext, sentinel, expected_pt_len=0): # Step 2a (O2SIP) ct_int = bytes_to_long(ciphertext) - # Step 2b (RSADP) - m_int = self._key._decrypt(ct_int) - - # Complete step 2c (I2OSP) - em = long_to_bytes(m_int, k) + # Step 2b (RSADP) and Step 2c (I2OSP) + em = self._key._decrypt(ct_int) # Step 3 (not constant time when the sentinel is not a byte string) output = bytes(bytearray(k)) diff --git a/lib/Crypto/Math/_IntegerBase.py b/lib/Crypto/Math/_IntegerBase.py index ec9cb478d..931743aa1 100644 --- a/lib/Crypto/Math/_IntegerBase.py +++ b/lib/Crypto/Math/_IntegerBase.py @@ -390,3 +390,23 @@ def random_range(cls, **kwargs): ) return norm_candidate + min_inclusive + @staticmethod + @abc.abstractmethod + def _mult_modulo_bytes(term1, term2, modulus): + """Multiply two integers, take the modulo, and encode as big endian. + This specialized method is used for RSA decryption. + + Args: + term1 : integer + The first term of the multiplication, non-negative. + term2 : integer + The second term of the multiplication, non-negative. + modulus: integer + The modulus, a positive odd number. + :Returns: + A byte string, with the result of the modular multiplication + encoded in big endian mode. + It is as long as the modulus would be, with zero padding + on the left if needed. + """ + pass diff --git a/lib/Crypto/Math/_IntegerBase.pyi b/lib/Crypto/Math/_IntegerBase.pyi index a42a48baf..ea235326f 100644 --- a/lib/Crypto/Math/_IntegerBase.pyi +++ b/lib/Crypto/Math/_IntegerBase.pyi @@ -60,4 +60,8 @@ class IntegerBase: def random(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ... @classmethod def random_range(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ... + @staticmethod + def _mult_modulo_bytes(term1: Union[IntegerBase, int], + term2: Union[IntegerBase, int], + modulus: Union[IntegerBase, int]) -> bytes: ... diff --git a/lib/Crypto/Math/_IntegerCustom.py b/lib/Crypto/Math/_IntegerCustom.py index d6f6f751a..7dfc2350d 100644 --- a/lib/Crypto/Math/_IntegerCustom.py +++ b/lib/Crypto/Math/_IntegerCustom.py @@ -41,12 +41,18 @@ from Crypto.Random.random import getrandbits c_defs = """ -int monty_pow(const uint8_t *base, - const uint8_t *exp, - const uint8_t *modulus, - uint8_t *out, - size_t len, - uint64_t seed); +int monty_pow(uint8_t *out, + const uint8_t *base, + const uint8_t *exp, + const uint8_t *modulus, + size_t len, + uint64_t seed); + +int monty_multiply(uint8_t *out, + const uint8_t *term1, + const uint8_t *term2, + const uint8_t *modulus, + size_t len); """ @@ -116,3 +122,41 @@ def inplace_pow(self, exponent, modulus=None): result = bytes_to_long(get_raw_buffer(out)) self._value = result return self + + @staticmethod + def _mult_modulo_bytes(term1, term2, modulus): + + # With modular reduction + mod_value = int(modulus) + if mod_value < 0: + raise ValueError("Modulus must be positive") + if mod_value == 0: + raise ZeroDivisionError("Modulus cannot be zero") + + # C extension only works with odd moduli + if (mod_value & 1) == 0: + raise ValueError("Odd modulus is required") + + # C extension only works with non-negative terms smaller than modulus + if term1 >= mod_value or term1 < 0: + term1 %= mod_value + if term2 >= mod_value or term2 < 0: + term2 %= mod_value + + modulus_b = long_to_bytes(mod_value) + numbers_len = len(modulus_b) + term1_b = long_to_bytes(term1, numbers_len) + term2_b = long_to_bytes(term2, numbers_len) + out = create_string_buffer(numbers_len) + + error = _raw_montgomery.monty_multiply( + out, + term1_b, + term2_b, + modulus_b, + c_size_t(numbers_len) + ) + if error: + raise ValueError("monty_multiply failed with error: %d" % error) + + return get_raw_buffer(out) diff --git a/lib/Crypto/Math/_IntegerGMP.py b/lib/Crypto/Math/_IntegerGMP.py index f552b71ad..e1b6d66cb 100644 --- a/lib/Crypto/Math/_IntegerGMP.py +++ b/lib/Crypto/Math/_IntegerGMP.py @@ -749,6 +749,26 @@ def jacobi_symbol(a, n): raise ValueError("n must be positive odd for the Jacobi symbol") return _gmp.mpz_jacobi(a._mpz_p, n._mpz_p) + @staticmethod + def _mult_modulo_bytes(term1, term2, modulus): + if not isinstance(term1, IntegerGMP): + term1 = IntegerGMP(term1) + if not isinstance(term2, IntegerGMP): + term2 = IntegerGMP(term2) + if not isinstance(modulus, IntegerGMP): + modulus = IntegerGMP(modulus) + + if modulus < 0: + raise ValueError("Modulus must be positive") + if modulus == 0: + raise ZeroDivisionError("Modulus cannot be zero") + if (modulus & 1) == 0: + raise ValueError("Odd modulus is required") + + numbers_len = len(modulus.to_bytes()) + result = ((term1 * term2) % modulus).to_bytes(numbers_len) + return result + # Clean-up def __del__(self): diff --git a/lib/Crypto/Math/_IntegerNative.py b/lib/Crypto/Math/_IntegerNative.py index 2173691ef..e9937d09b 100644 --- a/lib/Crypto/Math/_IntegerNative.py +++ b/lib/Crypto/Math/_IntegerNative.py @@ -368,3 +368,15 @@ def jacobi_symbol(a, n): n1 = n % a1 # Step 8 return s * IntegerNative.jacobi_symbol(n1, a1) + + @staticmethod + def _mult_modulo_bytes(term1, term2, modulus): + if modulus < 0: + raise ValueError("Modulus must be positive") + if modulus == 0: + raise ZeroDivisionError("Modulus cannot be zero") + if (modulus & 1) == 0: + raise ValueError("Odd modulus is required") + + number_len = len(long_to_bytes(modulus)) + return long_to_bytes((term1 * term2) % modulus, number_len) diff --git a/lib/Crypto/PublicKey/RSA.py b/lib/Crypto/PublicKey/RSA.py index f9beea4e0..9f4a0d35c 100644 --- a/lib/Crypto/PublicKey/RSA.py +++ b/lib/Crypto/PublicKey/RSA.py @@ -38,6 +38,7 @@ from Crypto import Random from Crypto.Util.py3compat import tobytes, bord, tostr from Crypto.Util.asn1 import DerSequence, DerNull +from Crypto.Util.number import bytes_to_long from Crypto.Math.Numbers import Integer from Crypto.Math.Primality import (test_probable_prime, @@ -198,10 +199,11 @@ def _decrypt(self, ciphertext): h = ((m2 - m1) * self._u) % self._q mp = h * self._p + m1 # Step 4: Compute m = m' * (r**(-1)) mod n - result = (r.inverse(self._n) * mp) % self._n - # Verify no faults occurred - if ciphertext != pow(result, self._e, self._n): - raise ValueError("Fault detected in RSA decryption") + # then encode into a big endian byte string + result = Integer._mult_modulo_bytes( + r.inverse(self._n), + mp, + self._n) return result def has_private(self): diff --git a/lib/Crypto/SelfTest/Math/__init__.py b/lib/Crypto/SelfTest/Math/__init__.py index 18e83d103..c72d7dc32 100644 --- a/lib/Crypto/SelfTest/Math/__init__.py +++ b/lib/Crypto/SelfTest/Math/__init__.py @@ -38,9 +38,11 @@ def get_tests(config={}): from Crypto.SelfTest.Math import test_Numbers from Crypto.SelfTest.Math import test_Primality from Crypto.SelfTest.Math import test_modexp + from Crypto.SelfTest.Math import test_modmult tests += test_Numbers.get_tests(config=config) tests += test_Primality.get_tests(config=config) tests += test_modexp.get_tests(config=config) + tests += test_modmult.get_tests(config=config) return tests if __name__ == '__main__': diff --git a/lib/Crypto/SelfTest/Math/test_Numbers.py b/lib/Crypto/SelfTest/Math/test_Numbers.py index 924eca48d..7609485a4 100644 --- a/lib/Crypto/SelfTest/Math/test_Numbers.py +++ b/lib/Crypto/SelfTest/Math/test_Numbers.py @@ -696,6 +696,34 @@ def test_hex(self): v1, = self.Integers(0x10) self.assertEqual(hex(v1), "0x10") + def test_mult_modulo_bytes(self): + modmult = self.Integer._mult_modulo_bytes + + res = modmult(4, 5, 19) + self.assertEqual(res, b'\x01') + + res = modmult(4 - 19, 5, 19) + self.assertEqual(res, b'\x01') + + res = modmult(4, 5 - 19, 19) + self.assertEqual(res, b'\x01') + + res = modmult(4 + 19, 5, 19) + self.assertEqual(res, b'\x01') + + res = modmult(4, 5 + 19, 19) + self.assertEqual(res, b'\x01') + + modulus = 2**512 - 1 # 64 bytes + t1 = 13**100 + t2 = 17**100 + expect = b"\xfa\xb2\x11\x87\xc3(y\x07\xf8\xf1n\xdepq\x0b\xca\xf3\xd3B,\xef\xf2\xfbf\xcc)\x8dZ*\x95\x98r\x96\xa8\xd5\xc3}\xe2q:\xa2'z\xf48\xde%\xef\t\x07\xbc\xc4[C\x8bUE2\x90\xef\x81\xaa:\x08" + self.assertEqual(expect, modmult(t1, t2, modulus)) + + self.assertRaises(ZeroDivisionError, modmult, 4, 5, 0) + self.assertRaises(ValueError, modmult, 4, 5, -1) + self.assertRaises(ValueError, modmult, 4, 5, 4) + class TestIntegerInt(TestIntegerBase): diff --git a/lib/Crypto/SelfTest/Math/test_modmult.py b/lib/Crypto/SelfTest/Math/test_modmult.py new file mode 100644 index 000000000..66aa3cd18 --- /dev/null +++ b/lib/Crypto/SelfTest/Math/test_modmult.py @@ -0,0 +1,120 @@ +# +# SelfTest/Math/test_modmult.py: Self-test for custom modular multiplication +# +# =================================================================== +# +# Copyright (c) 2023, Helder Eijs +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# =================================================================== + +"""Self-test for the custom modular multiplication""" + +import unittest + +from Crypto.SelfTest.st_common import list_test_cases + +from Crypto.Util.number import long_to_bytes, bytes_to_long + +from Crypto.Util._raw_api import (create_string_buffer, + get_raw_buffer, + c_size_t) + +from Crypto.Math._IntegerCustom import _raw_montgomery + + +class ExceptionModulus(ValueError): + pass + + +def monty_mult(term1, term2, modulus): + + if term1 >= modulus: + term1 %= modulus + if term2 >= modulus: + term2 %= modulus + + modulus_b = long_to_bytes(modulus) + numbers_len = len(modulus_b) + term1_b = long_to_bytes(term1, numbers_len) + term2_b = long_to_bytes(term2, numbers_len) + + out = create_string_buffer(numbers_len) + error = _raw_montgomery.monty_multiply( + out, + term1_b, + term2_b, + modulus_b, + c_size_t(numbers_len) + ) + + if error == 17: + raise ExceptionModulus() + if error: + raise ValueError("monty_multiply() failed with error: %d" % error) + + return get_raw_buffer(out) + + +modulus1 = 0xd66691b20071be4d66d4b71032b37fa007cfabf579fcb91e50bfc2753b3f0ce7be74e216aef7e26d4ae180bc20d7bd3ea88a6cbf6f87380e613c8979b5b043b200a8ff8856a3b12875e36e98a7569f3852d028e967551000b02c19e9fa52e83115b89309aabb1e1cf1e2cb6369d637d46775ce4523ea31f64ad2794cbc365dd8a35e007ed3b57695877fbf102dbeb8b3212491398e494314e93726926e1383f8abb5889bea954eb8c0ca1c62c8e9d83f41888095c5e645ed6d32515fe0c58c1368cad84694e18da43668c6f43e61d7c9bca633ddcda7aef5b79bc396d4a9f48e2a9abe0836cc455e435305357228e93d25aaed46b952defae0f57339bf26f5a9 + + +class TestModMultiply(unittest.TestCase): + + def test_small(self): + self.assertEqual(b"\x01", monty_mult(5, 6, 29)) + + def test_large(self): + numbers_len = (modulus1.bit_length() + 7) // 8 + + t1 = modulus1 // 2 + t2 = modulus1 - 90 + expect = b'\x00' * (numbers_len - 1) + b'\x2d' + self.assertEqual(expect, monty_mult(t1, t2, modulus1)) + + def test_zero_term(self): + numbers_len = (modulus1.bit_length() + 7) // 8 + expect = b'\x00' * numbers_len + self.assertEqual(expect, monty_mult(0x100, 0, modulus1)) + self.assertEqual(expect, monty_mult(0, 0x100, modulus1)) + + def test_larger_term(self): + t1 = 2**2047 + expect_int = 0x8edf4071f78e3d7ba622cdbbbef74612e301d69186776ae6bf87ff38c320d9aebaa64889c2f67de2324e6bccd2b10ad89e91fd21ba4bb523904d033eff5e70e62f01a84f41fa90a4f248ef249b82e1d2729253fdfc2a3b5b740198123df8bfbf7057d03e15244ad5f26eb9a099763b5c5972121ec076b0bf899f59bd95f7cc129abddccf24217bce52ca0f3a44c9ccc504765dbb89734205f3ae6a8cc560494a60ea84b27d8e00fa24bdd5b4f1d4232edb61e47d3d984c1fa50a3820a2e580fbc3fc8bc11e99df53b9efadf5a40ac75d384e400905aa6f1d88950cd53b1c54dc2222115ad84a27260fa4d978155c1434c551de1ee7361a17a2f79d4388f78a5d + res = bytes_to_long(monty_mult(t1, t1, modulus1)) + self.assertEqual(res, expect_int) + + +def get_tests(config={}): + tests = [] + tests += list_test_cases(TestModMultiply) + return tests + + +if __name__ == '__main__': + def suite(): + return unittest.TestSuite(get_tests()) + unittest.main(defaultTest='suite') diff --git a/lib/Crypto/SelfTest/PublicKey/test_RSA.py b/lib/Crypto/SelfTest/PublicKey/test_RSA.py index 38616af09..f85e9dcde 100644 --- a/lib/Crypto/SelfTest/PublicKey/test_RSA.py +++ b/lib/Crypto/SelfTest/PublicKey/test_RSA.py @@ -279,7 +279,7 @@ def _exercise_primitive(self, rsaObj): ciphertext = bytes_to_long(a2b_hex(self.ciphertext)) # Test decryption - plaintext = rsaObj._decrypt(ciphertext) + plaintext = bytes_to_long(rsaObj._decrypt(ciphertext)) # Test encryption (2 arguments) new_ciphertext2 = rsaObj._encrypt(plaintext) @@ -304,7 +304,7 @@ def _check_decryption(self, rsaObj): ciphertext = bytes_to_long(a2b_hex(self.ciphertext)) # Test plain decryption - new_plaintext = rsaObj._decrypt(ciphertext) + new_plaintext = bytes_to_long(rsaObj._decrypt(ciphertext)) self.assertEqual(plaintext, new_plaintext) diff --git a/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py b/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py index fa92fb0aa..06f7e97db 100644 --- a/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py +++ b/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py @@ -29,7 +29,7 @@ from Crypto.PublicKey import RSA from Crypto.SelfTest.st_common import a2b_hex, list_test_cases from Crypto.Util.py3compat import b, tostr, FileNotFoundError -from Crypto.Util.number import inverse +from Crypto.Util.number import inverse, bytes_to_long from Crypto.Util import asn1 try: @@ -239,13 +239,13 @@ def testImportKey4bytes(self): def testImportKey5(self): """Verifies that the imported key is still a valid RSA pair""" key = RSA.importKey(self.rsaKeyPEM) - idem = key._encrypt(key._decrypt(89)) + idem = key._encrypt(bytes_to_long(key._decrypt(89))) self.assertEqual(idem, 89) def testImportKey6(self): """Verifies that the imported key is still a valid RSA pair""" key = RSA.importKey(self.rsaKeyDER) - idem = key._encrypt(key._decrypt(65)) + idem = key._encrypt(bytes_to_long(key._decrypt(65))) self.assertEqual(idem, 65) def testImportKey7(self): diff --git a/lib/Crypto/Signature/pkcs1_15.py b/lib/Crypto/Signature/pkcs1_15.py index 40726f47d..90da4ab6e 100644 --- a/lib/Crypto/Signature/pkcs1_15.py +++ b/lib/Crypto/Signature/pkcs1_15.py @@ -77,10 +77,11 @@ def sign(self, msg_hash): em = _EMSA_PKCS1_V1_5_ENCODE(msg_hash, k) # Step 2a (OS2IP) em_int = bytes_to_long(em) - # Step 2b (RSASP1) - m_int = self._key._decrypt(em_int) - # Step 2c (I2OSP) - signature = long_to_bytes(m_int, k) + # Step 2b (RSASP1) and Step 2c (I2OSP) + signature = self._key._decrypt(em_int) + # Verify no faults occurred + if em_int != pow(bytes_to_long(signature), self._key.e, self._key.n): + raise ValueError("Fault detected in RSA private key operation") return signature def verify(self, msg_hash, signature): diff --git a/lib/Crypto/Signature/pss.py b/lib/Crypto/Signature/pss.py index 5f34ace78..eeb7082ed 100644 --- a/lib/Crypto/Signature/pss.py +++ b/lib/Crypto/Signature/pss.py @@ -107,10 +107,11 @@ def sign(self, msg_hash): em = _EMSA_PSS_ENCODE(msg_hash, modBits-1, self._randfunc, mgf, sLen) # Step 2a (OS2IP) em_int = bytes_to_long(em) - # Step 2b (RSASP1) - m_int = self._key._decrypt(em_int) - # Step 2c (I2OSP) - signature = long_to_bytes(m_int, k) + # Step 2b (RSASP1) and Step 2c (I2OSP) + signature = self._key._decrypt(em_int) + # Verify no faults occurred + if em_int != pow(bytes_to_long(signature), self._key.e, self._key.n): + raise ValueError("Fault detected in RSA private key operation") return signature def verify(self, msg_hash, signature): @@ -178,7 +179,7 @@ def MGF1(mgfSeed, maskLen, hash_gen): :return: the mask, as a *byte string* """ - + T = b"" for counter in iter_range(ceil_div(maskLen, hash_gen.digest_size)): c = long_to_bytes(counter, 4) diff --git a/src/bignum.c b/src/bignum.c index f74c87ae8..b4004acf6 100644 --- a/src/bignum.c +++ b/src/bignum.c @@ -165,7 +165,7 @@ STATIC int mod_select(uint64_t *out, const uint64_t *a, const uint64_t *b, unsig mask = (uint64_t)((cond != 0) - 1); /* 0 for a, 1s for b */ #if SYS_BITS == 64 - r0 = _mm_set1_epi64x(mask); + r0 = _mm_set1_epi64x((int64_t)mask); #else r0 = _mm_loadl_epi64((__m128i*)&mask); r0 = _mm_unpacklo_epi64(r0, r0); diff --git a/src/modexp.c b/src/modexp.c index c337c0194..f4df96acc 100644 --- a/src/modexp.c +++ b/src/modexp.c @@ -179,6 +179,71 @@ EXPORT_SYM int monty_pow( return res; } +/* + * Modular multiplication. All numbers are + * encoded in big endian form, possibly with + * zero padding on the left. + * + * @param out The memory area where to store the result + * @param term1 First term of the multiplication, strictly smaller than the modulus + * @param term2 Second term of the multiplication, strictly smaller than the modulus + * @param modulus Modulus, it must be odd + * @param len Size in bytes of out, term1, term2, and modulus + * @return 0 in case of success, the appropriate error code otherwise + */ +EXPORT_SYM int monty_multiply( + uint8_t *out, + const uint8_t *term1, + const uint8_t *term2, + const uint8_t *modulus, + size_t len) +{ + MontContext *ctx = NULL; + uint64_t *mont_term1 = NULL; + uint64_t *mont_term2 = NULL; + uint64_t *mont_output = NULL; + uint64_t *scratchpad = NULL; + int res; + + if (!term1 || !term2 || !modulus || !out) + return ERR_NULL; + + if (len == 0) + return ERR_NOT_ENOUGH_DATA; + + /* Allocations **/ + res = mont_context_init(&ctx, modulus, len); + if (res) + return res; + + res = mont_from_bytes(&mont_term1, term1, len, ctx); + if (res) goto cleanup; + + res = mont_from_bytes(&mont_term2, term2, len, ctx); + if (res) goto cleanup; + + res = mont_number(&mont_output, 1, ctx); + if (res) goto cleanup; + + res = mont_number(&scratchpad, SCRATCHPAD_NR, ctx); + if (res) goto cleanup; + + /* Multiply, then transform result back into big-endian, byte form **/ + res = mont_mult(mont_output, mont_term1, mont_term2, scratchpad, ctx); + if (res) goto cleanup; + + res = mont_to_bytes(out, len, mont_output, ctx); + +cleanup: + mont_context_free(ctx); + free(mont_term1); + free(mont_term2); + free(mont_output); + free(scratchpad); + + return res; +} + #ifdef MAIN int main(void) { @@ -205,7 +270,7 @@ int main(void) res = fread(out, 1, length, stdin); result = monty_pow(out, base, exponent, modulus, length, 12); - + free(base); free(modulus); free(exponent); @@ -232,5 +297,6 @@ int main(void) monty_pow(out, base, exponent, modulus, length, 12); } + monty_multiply(out, base, out, modulus, length); } #endif