@@ -3090,6 +3090,97 @@ def wrapper(sock):
30903090 with self .tcp_server (run (server )) as srv :
30913091 self .loop .run_until_complete (client (srv .addr ))
30923092
3093+ def test_first_data_after_wakeup (self ):
3094+ if self .implementation == 'asyncio' :
3095+ raise unittest .SkipTest ()
3096+
3097+ server_context = self ._create_server_ssl_context (
3098+ self .ONLYCERT , self .ONLYKEY )
3099+ client_context = self ._create_client_ssl_context ()
3100+ loop = self .loop
3101+ this = self
3102+ fut = self .loop .create_future ()
3103+
3104+ def client (sock , addr ):
3105+ try :
3106+ sock .connect (addr )
3107+
3108+ incoming = ssl .MemoryBIO ()
3109+ outgoing = ssl .MemoryBIO ()
3110+ sslobj = client_context .wrap_bio (incoming , outgoing )
3111+
3112+ # Do handshake manually so that we could collect the last piece
3113+ while True :
3114+ try :
3115+ sslobj .do_handshake ()
3116+ break
3117+ except ssl .SSLWantReadError :
3118+ if outgoing .pending :
3119+ sock .send (outgoing .read ())
3120+ incoming .write (sock .recv (65536 ))
3121+
3122+ # Send the first data together with the last handshake payload
3123+ sslobj .write (b'hello' )
3124+ sock .send (outgoing .read ())
3125+
3126+ while True :
3127+ try :
3128+ incoming .write (sock .recv (65536 ))
3129+ self .assertEqual (sslobj .read (1024 ), b'hello' )
3130+ break
3131+ except ssl .SSLWantReadError :
3132+ pass
3133+
3134+ sock .close ()
3135+
3136+ except Exception as ex :
3137+ loop .call_soon_threadsafe (fut .set_exception , ex )
3138+ sock .close ()
3139+ else :
3140+ loop .call_soon_threadsafe (fut .set_result , None )
3141+
3142+ class EchoProto (asyncio .Protocol ):
3143+ def connection_made (self , tr ):
3144+ self .tr = tr
3145+ # manually run the coroutine, in order to avoid accidental data
3146+ coro = loop .start_tls (
3147+ tr , self , server_context ,
3148+ server_side = True ,
3149+ ssl_handshake_timeout = this .TIMEOUT ,
3150+ )
3151+ waiter = coro .send (None )
3152+
3153+ def tls_started (_ ):
3154+ try :
3155+ coro .send (None )
3156+ except StopIteration as e :
3157+ # update self.tr to SSL transport as soon as we know it
3158+ self .tr = e .value
3159+
3160+ waiter .add_done_callback (tls_started )
3161+
3162+ def data_received (self , data ):
3163+ # This is a dumb protocol that writes back whatever it receives
3164+ # regardless of whether self.tr is SSL or not
3165+ self .tr .write (data )
3166+
3167+ async def run_main ():
3168+ proto = EchoProto ()
3169+
3170+ server = await self .loop .create_server (
3171+ lambda : proto , '127.0.0.1' , 0 )
3172+ addr = server .sockets [0 ].getsockname ()
3173+
3174+ with self .tcp_client (lambda sock : client (sock , addr ),
3175+ timeout = self .TIMEOUT ):
3176+ await asyncio .wait_for (fut , timeout = self .TIMEOUT )
3177+ proto .tr .close ()
3178+
3179+ server .close ()
3180+ await server .wait_closed ()
3181+
3182+ self .loop .run_until_complete (run_main ())
3183+
30933184
30943185class Test_UV_TCPSSL (_TestSSL , tb .UVTestCase ):
30953186 pass
0 commit comments