4747#include "tls.h"
4848
4949struct tls_decrypt_arg {
50+ struct_group (inargs ,
5051 bool zc ;
5152 bool async ;
5253 u8 tail ;
54+ );
55+
56+ struct sk_buff * skb ;
5357};
5458
5559struct tls_decrypt_ctx {
@@ -1412,6 +1416,7 @@ static int tls_setup_from_iter(struct iov_iter *from,
14121416 * -------------------------------------------------------------------
14131417 * zc | Zero-copy decrypt allowed | Zero-copy performed
14141418 * async | Async decrypt allowed | Async crypto used / in progress
1419+ * skb | * | Output skb
14151420 */
14161421
14171422/* This function decrypts the input skb into either out_iov or in out_sg
@@ -1551,12 +1556,17 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
15511556 /* Prepare and submit AEAD request */
15521557 err = tls_do_decryption (sk , skb , sgin , sgout , dctx -> iv ,
15531558 data_len + prot -> tail_size , aead_req , darg );
1559+ if (err )
1560+ goto exit_free_pages ;
1561+
1562+ darg -> skb = tls_strp_msg (ctx );
15541563 if (darg -> async )
15551564 return 0 ;
15561565
15571566 if (prot -> tail_size )
15581567 darg -> tail = dctx -> tail ;
15591568
1569+ exit_free_pages :
15601570 /* Release the pages in case iov was mapped to pages */
15611571 for (; pages > 0 ; pages -- )
15621572 put_page (sg_page (& sgout [pages ]));
@@ -1569,6 +1579,7 @@ static int
15691579tls_decrypt_device (struct sock * sk , struct tls_context * tls_ctx ,
15701580 struct tls_decrypt_arg * darg )
15711581{
1582+ struct tls_sw_context_rx * ctx = tls_sw_ctx_rx (tls_ctx );
15721583 int err ;
15731584
15741585 if (tls_ctx -> rx_conf != TLS_HW )
@@ -1580,6 +1591,8 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
15801591
15811592 darg -> zc = false;
15821593 darg -> async = false;
1594+ darg -> skb = tls_strp_msg (ctx );
1595+ ctx -> recv_pkt = NULL ;
15831596 return 1 ;
15841597}
15851598
@@ -1604,8 +1617,11 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
16041617 TLS_INC_STATS (sock_net (sk ), LINUX_MIB_TLSDECRYPTERROR );
16051618 return err ;
16061619 }
1607- if (darg -> async )
1620+ if (darg -> async ) {
1621+ if (darg -> skb == ctx -> recv_pkt )
1622+ ctx -> recv_pkt = NULL ;
16081623 goto decrypt_next ;
1624+ }
16091625 /* If opportunistic TLS 1.3 ZC failed retry without ZC */
16101626 if (unlikely (darg -> zc && prot -> version == TLS_1_3_VERSION &&
16111627 darg -> tail != TLS_RECORD_TYPE_DATA )) {
@@ -1616,12 +1632,17 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
16161632 return tls_rx_one_record (sk , dest , darg );
16171633 }
16181634
1635+ if (darg -> skb == ctx -> recv_pkt )
1636+ ctx -> recv_pkt = NULL ;
1637+
16191638decrypt_done :
1620- pad = tls_padding_length (prot , ctx -> recv_pkt , darg );
1621- if (pad < 0 )
1639+ pad = tls_padding_length (prot , darg -> skb , darg );
1640+ if (pad < 0 ) {
1641+ consume_skb (darg -> skb );
16221642 return pad ;
1643+ }
16231644
1624- rxm = strp_msg (ctx -> recv_pkt );
1645+ rxm = strp_msg (darg -> skb );
16251646 rxm -> full_len -= pad ;
16261647 rxm -> offset += prot -> prepend_size ;
16271648 rxm -> full_len -= prot -> overhead_size ;
@@ -1663,6 +1684,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
16631684
16641685static void tls_rx_rec_done (struct tls_sw_context_rx * ctx )
16651686{
1687+ consume_skb (ctx -> recv_pkt );
16661688 ctx -> recv_pkt = NULL ;
16671689 __strp_unpause (& ctx -> strp );
16681690}
@@ -1872,7 +1894,7 @@ int tls_sw_recvmsg(struct sock *sk,
18721894 ctx -> zc_capable ;
18731895 decrypted = 0 ;
18741896 while (len && (decrypted + copied < target || ctx -> recv_pkt )) {
1875- struct tls_decrypt_arg darg = {} ;
1897+ struct tls_decrypt_arg darg ;
18761898 int to_decrypt , chunk ;
18771899
18781900 err = tls_rx_rec_wait (sk , psock , flags & MSG_DONTWAIT , timeo );
@@ -1889,9 +1911,10 @@ int tls_sw_recvmsg(struct sock *sk,
18891911 goto recv_end ;
18901912 }
18911913
1892- skb = ctx -> recv_pkt ;
1893- rxm = strp_msg (skb );
1894- tlm = tls_msg (skb );
1914+ memset (& darg .inargs , 0 , sizeof (darg .inargs ));
1915+
1916+ rxm = strp_msg (ctx -> recv_pkt );
1917+ tlm = tls_msg (ctx -> recv_pkt );
18951918
18961919 to_decrypt = rxm -> full_len - prot -> overhead_size ;
18971920
@@ -1911,6 +1934,10 @@ int tls_sw_recvmsg(struct sock *sk,
19111934 goto recv_end ;
19121935 }
19131936
1937+ skb = darg .skb ;
1938+ rxm = strp_msg (skb );
1939+ tlm = tls_msg (skb );
1940+
19141941 async |= darg .async ;
19151942
19161943 /* If the type of records being processed is not known yet,
@@ -2051,21 +2078,23 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
20512078 if (!skb_queue_empty (& ctx -> rx_list )) {
20522079 skb = __skb_dequeue (& ctx -> rx_list );
20532080 } else {
2054- struct tls_decrypt_arg darg = {} ;
2081+ struct tls_decrypt_arg darg ;
20552082
20562083 err = tls_rx_rec_wait (sk , NULL , flags & SPLICE_F_NONBLOCK ,
20572084 timeo );
20582085 if (err <= 0 )
20592086 goto splice_read_end ;
20602087
2088+ memset (& darg .inargs , 0 , sizeof (darg .inargs ));
2089+
20612090 err = tls_rx_one_record (sk , NULL , & darg );
20622091 if (err < 0 ) {
20632092 tls_err_abort (sk , - EBADMSG );
20642093 goto splice_read_end ;
20652094 }
20662095
2067- skb = ctx -> recv_pkt ;
20682096 tls_rx_rec_done (ctx );
2097+ skb = darg .skb ;
20692098 }
20702099
20712100 rxm = strp_msg (skb );
0 commit comments