Skip to content

Commit

Permalink
Add possibility to use different hash algorithms in RSAES-OAEP
Browse files Browse the repository at this point in the history
The hash algorithms used in the MGF and to create the hash of the Label
must not forcibly be the same. This change allows to use different
algorithms.

Unfortunately this breaks the API if you use one of:
* `rsa_decrypt_key_ex()`
* `rsa_encrypt_key_ex()`
* `pkcs_1_oaep_decode()`
* `pkcs_1_oaep_encode()`

The `rsa_decrypt_key()` and `rsa_encrypt_key()` macros are still the same.

Signed-off-by: Steffen Jaeckel <s@jaeckel.eu>
  • Loading branch information
sjaeckel committed Oct 9, 2023
1 parent 91b7bbe commit 63091c9
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 87 deletions.
20 changes: 13 additions & 7 deletions doc/crypt.tex
Original file line number Diff line number Diff line change
Expand Up @@ -4190,7 +4190,8 @@ \subsection{OAEP Encoding}
unsigned long modulus_bitlen,
prng_state *prng,
int prng_idx,
int hash_idx,
int mgf_hash,
int lparam_hash,
unsigned char *out,
unsigned long *outlen);
\end{alltt}
Expand All @@ -4200,7 +4201,9 @@ \subsection{OAEP Encoding}
\textit{lparam} can be set to \textbf{NULL}.

OAEP encoding requires the length of the modulus in bits in order to calculate the size of the output. This is passed as the parameter
\textit{modulus\_bitlen}. \textit{hash\_idx} is the index into the hash descriptor table of the hash desired. PKCS \#1 allows any hash to be
\textit{modulus\_bitlen}. \textit{mgf\_hash} is the index into the hash descriptor table of the hash desired for the mask generation function (MGF).
\textit{lparam\_hash} is the index into the hash descriptor table of the hash desired for the \textit{lparam}. This value can also be set to $-1$
to indicate usage of the same algorithm than for the MGF. PKCS \#1 allows any hash to be
used but both the encoder and decoder must use the same hash in order for this to succeed. The size of hash output affects the maximum
sized input message. \textit{prng\_idx} and \textit{prng} are the random number generator arguments required to randomize the padding process.
The padded message is stored in \textit{out} along with the length in \textit{outlen}.
Expand All @@ -4221,7 +4224,8 @@ \subsection{OAEP Decoding}
const unsigned char *lparam,
unsigned long lparamlen,
unsigned long modulus_bitlen,
int hash_idx,
int mgf_hash,
int lparam_hash,
unsigned char *out,
unsigned long *outlen,
int *res);
Expand All @@ -4230,8 +4234,8 @@ \subsection{OAEP Decoding}
This function decodes an OAEP encoded message and outputs the original message that was passed to the OAEP encoder. \textit{msg} is the
output of pkcs\_1\_oaep\_encode() of length \textit{msglen}. \textit{lparam} is the same system variable passed to the OAEP encoder. If it does not
match what was used during encoding this function will not decode the packet. \textit{modulus\_bitlen} is the size of the RSA modulus in bits
and must match what was used during encoding. Similarly the \textit{hash\_idx} index into the hash descriptor table must match what was used
during encoding.
and must match what was used during encoding. Similarly the \textit{mgf\_hash} and \textit{lparam\_hash} indexes into the hash descriptor table must
match what was used during encoding.

If the function succeeds it decodes the OAEP encoded message into \textit{out} of length \textit{outlen} and stores a
$1$ in \textit{res}. If the packet is invalid it stores $0$ in \textit{res} and if the function fails for another reason
Expand Down Expand Up @@ -4426,7 +4430,8 @@ \subsection{Extended Encryption}
unsigned long lparamlen,
prng_state *prng,
int prng_idx,
int hash_idx,
int mgf_hash,
int lparam_hash,
int padding,
rsa_key *key);
\end{verbatim}
Expand All @@ -4447,7 +4452,8 @@ \subsection{Extended Encryption}
unsigned long *outlen,
const unsigned char *lparam,
unsigned long lparamlen,
int hash_idx,
int mgf_hash,
int lparam_hash,
int *stat,
rsa_key *key);
\end{verbatim}
Expand Down
10 changes: 6 additions & 4 deletions src/headers/tomcrypt_pk.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ void rsa_free(rsa_key *key);

/* These use PKCS #1 v2.0 padding */
#define rsa_encrypt_key(in, inlen, out, outlen, lparam, lparamlen, prng, prng_idx, hash_idx, key) \
rsa_encrypt_key_ex(in, inlen, out, outlen, lparam, lparamlen, prng, prng_idx, hash_idx, LTC_PKCS_1_OAEP, key)
rsa_encrypt_key_ex(in, inlen, out, outlen, lparam, lparamlen, prng, prng_idx, hash_idx, -1, LTC_PKCS_1_OAEP, key)

#define rsa_decrypt_key(in, inlen, out, outlen, lparam, lparamlen, hash_idx, stat, key) \
rsa_decrypt_key_ex(in, inlen, out, outlen, lparam, lparamlen, hash_idx, LTC_PKCS_1_OAEP, stat, key)
rsa_decrypt_key_ex(in, inlen, out, outlen, lparam, lparamlen, hash_idx, -1, LTC_PKCS_1_OAEP, stat, key)

#define rsa_sign_hash(in, inlen, out, outlen, prng, prng_idx, hash_idx, saltlen, key) \
rsa_sign_hash_ex(in, inlen, out, outlen, LTC_PKCS_1_PSS, prng, prng_idx, hash_idx, saltlen, key)
Expand All @@ -76,13 +76,15 @@ int rsa_encrypt_key_ex(const unsigned char *in, unsigned long inlen,
unsigned char *out, unsigned long *outlen,
const unsigned char *lparam, unsigned long lparamlen,
prng_state *prng, int prng_idx,
int hash_idx, int padding,
int mgf_hash, int lparam_hash,
int padding,
const rsa_key *key);

int rsa_decrypt_key_ex(const unsigned char *in, unsigned long inlen,
unsigned char *out, unsigned long *outlen,
const unsigned char *lparam, unsigned long lparamlen,
int hash_idx, int padding,
int mgf_hash, int lparam_hash,
int padding,
int *stat, const rsa_key *key);

int rsa_sign_hash_ex(const unsigned char *in, unsigned long inlen,
Expand Down
6 changes: 4 additions & 2 deletions src/headers/tomcrypt_pkcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ int pkcs_1_v1_5_decode(const unsigned char *msg,
int pkcs_1_oaep_encode(const unsigned char *msg, unsigned long msglen,
const unsigned char *lparam, unsigned long lparamlen,
unsigned long modulus_bitlen, prng_state *prng,
int prng_idx, int hash_idx,
int prng_idx,
int mgf_hash, int lparam_hash,
unsigned char *out, unsigned long *outlen);

int pkcs_1_oaep_decode(const unsigned char *msg, unsigned long msglen,
const unsigned char *lparam, unsigned long lparamlen,
unsigned long modulus_bitlen, int hash_idx,
unsigned long modulus_bitlen,
int mgf_hash, int lparam_hash,
unsigned char *out, unsigned long *outlen,
int *res);

Expand Down
28 changes: 19 additions & 9 deletions src/pk/pkcs1/pkcs_1_oaep_decode.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@
@param lparam The session or system data (can be NULL)
@param lparamlen The length of the lparam
@param modulus_bitlen The bit length of the RSA modulus
@param hash_idx The index of the hash desired
@param mgf_hash The hash algorithm used for the MGF
@param lparam_hash The hash algorithm used when hashing the lparam (can be -1)
@param out [out] Destination of decoding
@param outlen [in/out] The max size and resulting size of the decoding
@param res [out] Result of decoding, 1==valid, 0==invalid
@return CRYPT_OK if successful
*/
int pkcs_1_oaep_decode(const unsigned char *msg, unsigned long msglen,
const unsigned char *lparam, unsigned long lparamlen,
unsigned long modulus_bitlen, int hash_idx,
unsigned long modulus_bitlen,
int mgf_hash, int lparam_hash,
unsigned char *out, unsigned long *outlen,
int *res)
{
unsigned char *DB, *seed, *mask;
unsigned long hLen, x, y, modulus_len;
int err, ret;
int err, ret, lparam_hash_used;

LTC_ARGCHK(msg != NULL);
LTC_ARGCHK(out != NULL);
Expand All @@ -41,10 +43,18 @@ int pkcs_1_oaep_decode(const unsigned char *msg, unsigned long msglen,
*res = 0;

/* test valid hash */
if ((err = hash_is_valid(hash_idx)) != CRYPT_OK) {
if ((err = hash_is_valid(mgf_hash)) != CRYPT_OK) {
return err;
}
hLen = hash_descriptor[hash_idx].hashsize;
if (lparam_hash != -1) {
if ((err = hash_is_valid(lparam_hash)) != CRYPT_OK) {
return err;
}
lparam_hash_used = lparam_hash;
} else {
lparam_hash_used = mgf_hash;
}
hLen = hash_descriptor[lparam_hash_used].hashsize;
modulus_len = (modulus_bitlen >> 3) + (modulus_bitlen & 7 ? 1 : 0);

/* test hash/message size */
Expand Down Expand Up @@ -94,7 +104,7 @@ int pkcs_1_oaep_decode(const unsigned char *msg, unsigned long msglen,
x += modulus_len - hLen - 1;

/* compute MGF1 of maskedDB (hLen) */
if ((err = pkcs_1_mgf1(hash_idx, DB, modulus_len - hLen - 1, mask, hLen)) != CRYPT_OK) {
if ((err = pkcs_1_mgf1(mgf_hash, DB, modulus_len - hLen - 1, mask, hLen)) != CRYPT_OK) {
goto LBL_ERR;
}

Expand All @@ -104,7 +114,7 @@ int pkcs_1_oaep_decode(const unsigned char *msg, unsigned long msglen,
}

/* compute MGF1 of seed (k - hlen - 1) */
if ((err = pkcs_1_mgf1(hash_idx, seed, hLen, mask, modulus_len - hLen - 1)) != CRYPT_OK) {
if ((err = pkcs_1_mgf1(mgf_hash, seed, hLen, mask, modulus_len - hLen - 1)) != CRYPT_OK) {
goto LBL_ERR;
}

Expand All @@ -118,12 +128,12 @@ int pkcs_1_oaep_decode(const unsigned char *msg, unsigned long msglen,
/* compute lhash and store it in seed [reuse temps!] */
x = modulus_len;
if (lparam != NULL) {
if ((err = hash_memory(hash_idx, lparam, lparamlen, seed, &x)) != CRYPT_OK) {
if ((err = hash_memory(lparam_hash_used, lparam, lparamlen, seed, &x)) != CRYPT_OK) {
goto LBL_ERR;
}
} else {
/* can't pass hash_memory a NULL so use DB with zero length */
if ((err = hash_memory(hash_idx, DB, 0, seed, &x)) != CRYPT_OK) {
if ((err = hash_memory(lparam_hash_used, DB, 0, seed, &x)) != CRYPT_OK) {
goto LBL_ERR;
}
}
Expand Down
25 changes: 17 additions & 8 deletions src/pk/pkcs1/pkcs_1_oaep_encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,37 @@
int pkcs_1_oaep_encode(const unsigned char *msg, unsigned long msglen,
const unsigned char *lparam, unsigned long lparamlen,
unsigned long modulus_bitlen, prng_state *prng,
int prng_idx, int hash_idx,
int prng_idx,
int mgf_hash, int lparam_hash,
unsigned char *out, unsigned long *outlen)
{
unsigned char *DB, *seed, *mask;
unsigned long hLen, x, y, modulus_len;
int err;
int err, lparam_hash_used;

LTC_ARGCHK((msglen == 0) || (msg != NULL));
LTC_ARGCHK(out != NULL);
LTC_ARGCHK(outlen != NULL);

/* test valid hash */
if ((err = hash_is_valid(hash_idx)) != CRYPT_OK) {
if ((err = hash_is_valid(mgf_hash)) != CRYPT_OK) {
return err;
}
if (lparam_hash != -1) {
if ((err = hash_is_valid(lparam_hash)) != CRYPT_OK) {
return err;
}
lparam_hash_used = lparam_hash;
} else {
lparam_hash_used = mgf_hash;
}

/* valid prng */
if ((err = prng_is_valid(prng_idx)) != CRYPT_OK) {
return err;
}

hLen = hash_descriptor[hash_idx].hashsize;
hLen = hash_descriptor[lparam_hash_used].hashsize;
modulus_len = (modulus_bitlen >> 3) + (modulus_bitlen & 7 ? 1 : 0);

/* test message size */
Expand Down Expand Up @@ -76,12 +85,12 @@ int pkcs_1_oaep_encode(const unsigned char *msg, unsigned long msglen,
/* DB == lhash || PS || 0x01 || M, PS == k - mlen - 2hlen - 2 zeroes */
x = modulus_len;
if (lparam != NULL) {
if ((err = hash_memory(hash_idx, lparam, lparamlen, DB, &x)) != CRYPT_OK) {
if ((err = hash_memory(lparam_hash_used, lparam, lparamlen, DB, &x)) != CRYPT_OK) {
goto LBL_ERR;
}
} else {
/* can't pass hash_memory a NULL so use DB with zero length */
if ((err = hash_memory(hash_idx, DB, 0, DB, &x)) != CRYPT_OK) {
if ((err = hash_memory(lparam_hash_used, DB, 0, DB, &x)) != CRYPT_OK) {
goto LBL_ERR;
}
}
Expand All @@ -108,7 +117,7 @@ int pkcs_1_oaep_encode(const unsigned char *msg, unsigned long msglen,
}

/* compute MGF1 of seed (k - hlen - 1) */
if ((err = pkcs_1_mgf1(hash_idx, seed, hLen, mask, modulus_len - hLen - 1)) != CRYPT_OK) {
if ((err = pkcs_1_mgf1(mgf_hash, seed, hLen, mask, modulus_len - hLen - 1)) != CRYPT_OK) {
goto LBL_ERR;
}

Expand All @@ -118,7 +127,7 @@ int pkcs_1_oaep_encode(const unsigned char *msg, unsigned long msglen,
}

/* compute MGF1 of maskedDB (hLen) */
if ((err = pkcs_1_mgf1(hash_idx, DB, modulus_len - hLen - 1, mask, hLen)) != CRYPT_OK) {
if ((err = pkcs_1_mgf1(mgf_hash, DB, modulus_len - hLen - 1, mask, hLen)) != CRYPT_OK) {
goto LBL_ERR;
}

Expand Down
13 changes: 7 additions & 6 deletions src/pk/rsa/rsa_decrypt_key.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
@param outlen [in/out] The max size and resulting size of the plaintext (octets)
@param lparam The system "lparam" value
@param lparamlen The length of the lparam value (octets)
@param hash_idx The index of the hash desired
@param mgf_hash The hash algorithm used for the MGF
@param lparam_hash The hash algorithm used when hashing the lparam (can be -1)
@param padding Type of padding (LTC_PKCS_1_OAEP or LTC_PKCS_1_V1_5)
@param stat [out] Result of the decryption, 1==valid, 0==invalid
@param key The corresponding private RSA key
Expand All @@ -26,7 +27,8 @@
int rsa_decrypt_key_ex(const unsigned char *in, unsigned long inlen,
unsigned char *out, unsigned long *outlen,
const unsigned char *lparam, unsigned long lparamlen,
int hash_idx, int padding,
int mgf_hash, int lparam_hash,
int padding,
int *stat, const rsa_key *key)
{
unsigned long modulus_bitlen, modulus_bytelen, x;
Expand All @@ -43,15 +45,14 @@ int rsa_decrypt_key_ex(const unsigned char *in, unsigned long inlen
*stat = 0;

/* valid padding? */

if ((padding != LTC_PKCS_1_V1_5) &&
(padding != LTC_PKCS_1_OAEP)) {
return CRYPT_PK_INVALID_PADDING;
}

if (padding == LTC_PKCS_1_OAEP) {
/* valid hash ? */
if ((err = hash_is_valid(hash_idx)) != CRYPT_OK) {
if ((err = hash_is_valid(mgf_hash)) != CRYPT_OK) {
return err;
}
}
Expand Down Expand Up @@ -80,8 +81,8 @@ int rsa_decrypt_key_ex(const unsigned char *in, unsigned long inlen

if (padding == LTC_PKCS_1_OAEP) {
/* now OAEP decode the packet */
err = pkcs_1_oaep_decode(tmp, x, lparam, lparamlen, modulus_bitlen, hash_idx,
out, outlen, stat);
err = pkcs_1_oaep_decode(tmp, x, lparam, lparamlen, modulus_bitlen, mgf_hash,
lparam_hash, out, outlen, stat);
} else {
/* now PKCS #1 v1.5 depad the packet */
err = pkcs_1_v1_5_decode(tmp, x, LTC_PKCS_1_EME, modulus_bitlen, out, outlen, stat);
Expand Down
9 changes: 5 additions & 4 deletions src/pk/rsa/rsa_encrypt_key.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ int rsa_encrypt_key_ex(const unsigned char *in, unsigned long inlen,
unsigned char *out, unsigned long *outlen,
const unsigned char *lparam, unsigned long lparamlen,
prng_state *prng, int prng_idx,
int hash_idx, int padding,
int mgf_hash, int lparam_hash,
int padding,
const rsa_key *key)
{
unsigned long modulus_bitlen, modulus_bytelen, x;
Expand All @@ -52,7 +53,7 @@ int rsa_encrypt_key_ex(const unsigned char *in, unsigned long inlen,

if (padding == LTC_PKCS_1_OAEP) {
/* valid hash? */
if ((err = hash_is_valid(hash_idx)) != CRYPT_OK) {
if ((err = hash_is_valid(mgf_hash)) != CRYPT_OK) {
return err;
}
}
Expand All @@ -71,8 +72,8 @@ int rsa_encrypt_key_ex(const unsigned char *in, unsigned long inlen,
/* OAEP pad the key */
x = *outlen;
if ((err = pkcs_1_oaep_encode(in, inlen, lparam,
lparamlen, modulus_bitlen, prng, prng_idx, hash_idx,
out, &x)) != CRYPT_OK) {
lparamlen, modulus_bitlen, prng, prng_idx, mgf_hash,
lparam_hash, out, &x)) != CRYPT_OK) {
return err;
}
} else {
Expand Down
4 changes: 2 additions & 2 deletions tests/pkcs_1_eme_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ int pkcs_1_eme_test(void)
unsigned long buflen = sizeof(buf), obuflen = sizeof(obuf);
int stat;
prng_descriptor[prng_idx].add_entropy(s->o2, s->o2_l, (void*)no_prng_desc);
DOX(rsa_encrypt_key_ex(s->o1, s->o1_l, obuf, &obuflen, NULL, 0, (void*)no_prng_desc, prng_idx, -1, LTC_PKCS_1_V1_5, key), s->name);
DOX(rsa_encrypt_key_ex(s->o1, s->o1_l, obuf, &obuflen, NULL, 0, (void*)no_prng_desc, prng_idx, -1, -1, LTC_PKCS_1_V1_5, key), s->name);
COMPARE_TESTVECTOR(obuf, obuflen, s->o3, s->o3_l,s->name, j);
DOX(rsa_decrypt_key_ex(obuf, obuflen, buf, &buflen, NULL, 0, -1, LTC_PKCS_1_V1_5, &stat, key), s->name);
DOX(rsa_decrypt_key_ex(obuf, obuflen, buf, &buflen, NULL, 0, -1, -1, LTC_PKCS_1_V1_5, &stat, key), s->name);
DOX(stat == 1?CRYPT_OK:CRYPT_FAIL_TESTVECTOR, s->name);
} /* for */

Expand Down
4 changes: 2 additions & 2 deletions tests/pkcs_1_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ int pkcs_1_test(void)

/* encode it */
l1 = sizeof(buf[1]);
DO(pkcs_1_oaep_encode(buf[0], l3, lparam, lparamlen, modlen, &yarrow_prng, prng_idx, hash_idx, buf[1], &l1));
DO(pkcs_1_oaep_encode(buf[0], l3, lparam, lparamlen, modlen, &yarrow_prng, prng_idx, hash_idx, -1, buf[1], &l1));

/* decode it */
l2 = sizeof(buf[2]);
DO(pkcs_1_oaep_decode(buf[1], l1, lparam, lparamlen, modlen, hash_idx, buf[2], &l2, &res1));
DO(pkcs_1_oaep_decode(buf[1], l1, lparam, lparamlen, modlen, hash_idx, -1, buf[2], &l2, &res1));

if (res1 != 1 || l2 != l3 || memcmp(buf[2], buf[0], l3) != 0) {
fprintf(stderr, "Outsize == %lu, should have been %lu, res1 = %d, lparamlen = %lu, msg contents follow.\n", l2, l3, res1, lparamlen);
Expand Down
Loading

0 comments on commit 63091c9

Please sign in to comment.