Skip to content

Commit 23c5678

Browse files
committed
RSA decrypt: don't write past buffer end on error
When the decrypted data is bigger than the buffer, the one extra bytes was being written to.
1 parent 59f4fa5 commit 23c5678

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

tests/api/test_rsa.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,15 +789,18 @@ int test_wc_RsaPublicEncryptDecrypt(void)
789789
WC_DECLARE_VAR(in, byte, TEST_STRING_SZ, NULL);
790790
WC_DECLARE_VAR(plain, byte, TEST_STRING_SZ, NULL);
791791
WC_DECLARE_VAR(cipher, byte, TEST_RSA_BYTES, NULL);
792+
WC_DECLARE_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL);
792793

793794
WC_ALLOC_VAR(in, byte, TEST_STRING_SZ, NULL);
794795
WC_ALLOC_VAR(plain, byte, TEST_STRING_SZ, NULL);
795796
WC_ALLOC_VAR(cipher, byte, TEST_RSA_BYTES, NULL);
797+
WC_ALLOC_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL);
796798

797799
#ifdef WC_DECLARE_VAR_IS_HEAP_ALLOC
798800
ExpectNotNull(in);
799801
ExpectNotNull(plain);
800802
ExpectNotNull(cipher);
803+
ExpectNotNull(shortPlain);
801804
#endif
802805
ExpectNotNull(XMEMCPY(in, inStr, inLen));
803806

@@ -824,6 +827,11 @@ int test_wc_RsaPublicEncryptDecrypt(void)
824827
ExpectIntEQ(XMEMCMP(plain, inStr, plainLen), 0);
825828
/* Pass bad args - tested in another testing function.*/
826829

830+
/* Test for when plain length is less than required. */
831+
ExpectIntEQ(wc_RsaPrivateDecrypt(cipher, cipherLenResult, shortPlain,
832+
TEST_STRING_SZ - 4, &key), RSA_BUFFER_E);
833+
834+
WC_FREE_VAR(shortPlain, NULL);
827835
WC_FREE_VAR(in, NULL);
828836
WC_FREE_VAR(plain, NULL);
829837
WC_FREE_VAR(cipher, NULL);

wolfcrypt/src/rsa.c

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3636,15 +3636,28 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
36363636
if (rsa_type == RSA_PRIVATE_DECRYPT) {
36373637
word32 i = 0;
36383638
word32 j;
3639+
byte last = 0;
36393640
int start = (int)((size_t)pad - (size_t)key->data);
36403641

36413642
for (j = 0; j < key->dataLen; j++) {
3642-
signed char c;
3643-
out[i] = key->data[j];
3644-
c = (signed char)ctMaskGTE((int)j, start);
3645-
c &= (signed char)ctMaskLT((int)i, (int)outLen);
3646-
/* 0 - no add, -1 add */
3647-
i += (word32)((byte)(-c));
3643+
signed char incMask;
3644+
signed char maskData;
3645+
3646+
/* When j < start + outLen then out[i] = key->data[j]
3647+
* else out[i] = last
3648+
*/
3649+
maskData = (signed char)ctMaskLT((int)j,
3650+
start + (int)outLen);
3651+
out[i] = (byte)(key->data[j] & maskData ) |
3652+
(byte)(last & (~maskData));
3653+
last = out[i];
3654+
3655+
/* Increment i when j is in range:
3656+
* [start..(start + outLen - 1)]. */
3657+
incMask = (signed char)ctMaskGTE((int)j, start);
3658+
incMask &= (signed char)ctMaskLT((int)j,
3659+
start + (int)outLen - 1);
3660+
i += (word32)((byte)(-incMask));
36483661
}
36493662
}
36503663
else

0 commit comments

Comments
 (0)