diff --git a/src/tls13.c b/src/tls13.c index aa2ab160de..0aa331e699 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -8759,6 +8759,10 @@ typedef struct Scv13Args { byte sigAlgo; byte* sigData; word16 sigDataSz; +#ifndef NO_RSA + byte* toSign; /* not allocated */ + word32 toSignSz; +#endif #ifdef WOLFSSL_DUAL_ALG_CERTS byte altSigAlgo; word32 altSigLen; /* Only used in the case of both native and alt. */ @@ -9313,7 +9317,17 @@ static int SendTls13CertificateVerify(WOLFSSL* ssl) #endif /* HAVE_DILITHIUM */ #ifndef NO_RSA if (ssl->hsType == DYNAMIC_TYPE_RSA) { - ret = RsaSign(ssl, rsaSigBuf->buffer, (word32)rsaSigBuf->length, + args->toSign = rsaSigBuf->buffer; + args->toSignSz = (word32)rsaSigBuf->length; + #if defined(HAVE_PK_CALLBACKS) && \ + defined(TLS13_RSA_PSS_SIGN_CB_NO_PREHASH) + /* Pass full data to sign (args->sigData), not hash of */ + if (ssl->ctx->RsaPssSignCb) { + args->toSign = args->sigData; + args->toSignSz = args->sigDataSz; + } + #endif + ret = RsaSign(ssl, (const byte*)args->toSign, args->toSignSz, sigOut, &args->sigLen, args->sigAlgo, ssl->options.hashAlgo, (RsaKey*)ssl->hsKey, ssl->buffers.key); @@ -9357,10 +9371,20 @@ static int SendTls13CertificateVerify(WOLFSSL* ssl) #endif /* HAVE_ECC */ #ifndef NO_RSA if (ssl->hsAltType == DYNAMIC_TYPE_RSA) { - ret = RsaSign(ssl, rsaSigBuf->buffer, - (word32)rsaSigBuf->length, sigOut, - &args->altSigLen, args->altSigAlgo, - ssl->options.hashAlgo, (RsaKey*)ssl->hsAltKey, + args->toSign = rsaSigBuf->buffer; + args->toSignSz = (word32)rsaSigBuf->length; + #if defined(HAVE_PK_CALLBACKS) && \ + defined(TLS13_RSA_PSS_SIGN_CB_NO_PREHASH) + /* Pass full data to sign (args->altSigData), not hash of */ + if (ssl->ctx->RsaPssSignCb) { + args->toSign = args->altSigData; + args->toSignSz = (word32)args->altSigDataSz; + } + #endif + ret = RsaSign(ssl, (const byte*)args->toSign, + args->toSignSz, sigOut, &args->altSigLen, + args->altSigAlgo, ssl->options.hashAlgo, + (RsaKey*)ssl->hsAltKey, ssl->buffers.altKey); if (ret == 0) { diff --git a/wolfssl/test.h b/wolfssl/test.h index 888d7f1ae4..26fa9c3815 100644 --- a/wolfssl/test.h +++ b/wolfssl/test.h @@ -3902,9 +3902,11 @@ static WC_INLINE int myRsaPssSign(WOLFSSL* ssl, const byte* in, word32 inSz, { enum wc_HashType hashType = WC_HASH_TYPE_NONE; WC_RNG rng; - int ret; + int ret = 0; word32 idx = 0; RsaKey myKey; + byte* inBuf = (byte*)in; + word32 inBufSz = inSz; byte* keyBuf = (byte*)key; PkCbInfo* cbInfo = (PkCbInfo*)ctx; @@ -3942,17 +3944,40 @@ static WC_INLINE int myRsaPssSign(WOLFSSL* ssl, const byte* in, word32 inSz, if (ret != 0) return ret; - ret = wc_InitRsaKey(&myKey, NULL); + #ifdef TLS13_RSA_PSS_SIGN_CB_NO_PREHASH + /* With this defined, RSA-PSS sign callback when used from TLS 1.3 + * does not hash data before giving to this callback. User must + * compute hash themselves. */ + if (wolfSSL_GetVersion(ssl) == WOLFSSL_TLSV1_3) { + inBufSz = wc_HashGetDigestSize(hashType); + inBuf = (byte*)XMALLOC(inBufSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (inBuf == NULL) { + ret = MEMORY_E; + } + if (ret == 0) { + ret = wc_Hash(hashType, in, inSz, inBuf, inBufSz); + } + } + #endif + + if (ret == 0) { + ret = wc_InitRsaKey(&myKey, NULL); + } if (ret == 0) { ret = wc_RsaPrivateKeyDecode(keyBuf, &idx, &myKey, keySz); if (ret == 0) { - ret = wc_RsaPSS_Sign(in, inSz, out, *outSz, hashType, mgf, &myKey, - &rng); + ret = wc_RsaPSS_Sign(inBuf, inBufSz, out, *outSz, hashType, mgf, + &myKey, &rng); } if (ret > 0) { /* save and convert to 0 success */ *outSz = (word32) ret; ret = 0; } + #ifdef TLS13_RSA_PSS_SIGN_CB_NO_PREHASH + if ((inBuf != NULL) && (wolfSSL_GetVersion(ssl) == WOLFSSL_TLSV1_3)) { + XFREE(inBuf, NULL, DYNAMIC_TYPE_TMP_BUFFER); + } + #endif wc_FreeRsaKey(&myKey); } wc_FreeRng(&rng);