@@ -273,161 +273,60 @@ where
273273mod  test { 
274274    use  std:: sync:: Arc ; 
275275
276+     use  futures:: { future,  pin_mut} ; 
276277    use  tokio:: sync:: Notify ; 
277278    use  turmoil:: net:: { TcpListener ,  TcpStream } ; 
278-     use  uuid:: Uuid ; 
279279
280280    use  super :: * ; 
281281
282282    #[ test]  
283283    fn  invalid_handshake ( )  { 
284284        let  mut  sim = turmoil:: Builder :: new ( ) . build ( ) ; 
285285
286-         let  host_node_id = NodeId :: new_v4 ( ) ; 
287-         sim. host ( "host" ,  move  || async  move  { 
288-             let  bus = Bus :: new ( host_node_id) ; 
289-             let  listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" ) 
290-                 . await 
291-                 . unwrap ( ) ; 
292-             let  ( s,  _)  = listener. accept ( ) . await . unwrap ( ) ; 
293-             let  mut  connection = Connection :: new_acceptor ( s,  bus) ; 
294-             connection. tick ( ) . await ; 
295- 
296-             Ok ( ( ) ) 
286+         let  host_node_id = 0 ; 
287+         let  done = Arc :: new ( Notify :: new ( ) ) ; 
288+         let  done_clone = done. clone ( ) ; 
289+         sim. host ( "host" ,  move  || { 
290+             let  done_clone = done_clone. clone ( ) ; 
291+             async  move  { 
292+                 let  bus = Arc :: new ( Bus :: new ( host_node_id,  |_,  _| async  { } ) ) ; 
293+                 let  listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" ) 
294+                     . await 
295+                     . unwrap ( ) ; 
296+                 let  ( s,  _)  = listener. accept ( ) . await . unwrap ( ) ; 
297+                 let  connection = Connection :: new_acceptor ( s,  bus) ; 
298+                 let  done = done_clone. notified ( ) ; 
299+                 let  run = connection. run ( ) ; 
300+                 pin_mut ! ( done) ; 
301+                 pin_mut ! ( run) ; 
302+                 future:: select ( run,  done) . await ; 
303+ 
304+                 Ok ( ( ) ) 
305+             } 
297306        } ) ; 
298307
299308        sim. client ( "client" ,  async  move  { 
300309            let  s = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ; 
301-             let  mut  s = AsyncBincodeStream :: < _ ,  Message ,  Message ,  _ > :: from ( s) . for_async ( ) ; 
302- 
303-             s. send ( Message :: Node ( NodeMessage :: Handshake  { 
304-                 protocol_version :  1234 , 
305-                 node_id :  Uuid :: new_v4 ( ) , 
306-             } ) ) 
307-             . await 
308-             . unwrap ( ) ; 
310+             let  mut  s = AsyncBincodeStream :: < _ ,  Enveloppe ,  Enveloppe ,  _ > :: from ( s) . for_async ( ) ; 
311+ 
312+             let  msg = Enveloppe  { 
313+                 database_id :  None , 
314+                 message :  Message :: Handshake  { 
315+                     protocol_version :  1234 , 
316+                     node_id :  1 , 
317+                 } , 
318+             } ; 
319+             s. send ( msg) . await . unwrap ( ) ; 
309320            let  m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ; 
310321
311322            assert ! ( matches!( 
312-                 m, 
313-                 Message :: Node ( NodeMessage :: Error ( 
314-                     NodeError :: HandshakeVersionMismatch  {  .. } 
315-                 ) ) 
323+                 m. message , 
324+                 Message :: Error ( 
325+                     ProtoError :: HandshakeVersionMismatch  {  .. } 
326+                 ) 
316327            ) ) ; 
317328
318-             Ok ( ( ) ) 
319-         } ) ; 
320- 
321-         sim. run ( ) . unwrap ( ) ; 
322-     } 
323- 
324-     #[ test]  
325-     fn  stream_closed ( )  { 
326-         let  mut  sim = turmoil:: Builder :: new ( ) . build ( ) ; 
327- 
328-         let  database_id = DatabaseId :: new_v4 ( ) ; 
329-         let  host_node_id = NodeId :: new_v4 ( ) ; 
330-         let  notify = Arc :: new ( Notify :: new ( ) ) ; 
331-         sim. host ( "host" ,  { 
332-             let  notify = notify. clone ( ) ; 
333-             move  || { 
334-                 let  notify = notify. clone ( ) ; 
335-                 async  move  { 
336-                     let  bus = Bus :: new ( host_node_id) ; 
337-                     let  mut  sub = bus. subscribe ( database_id) . unwrap ( ) ; 
338-                     let  listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" ) 
339-                         . await 
340-                         . unwrap ( ) ; 
341-                     let  ( s,  _)  = listener. accept ( ) . await . unwrap ( ) ; 
342-                     let  connection = Connection :: new_acceptor ( s,  bus) ; 
343-                     tokio:: task:: spawn_local ( connection. run ( ) ) ; 
344-                     let  mut  streams = Vec :: new ( ) ; 
345-                     loop  { 
346-                         tokio:: select! { 
347-                             Some ( mut  stream)  = sub. next( )  => { 
348-                                 let  m = stream. next( ) . await . unwrap( ) ; 
349-                                 stream. send( m) . await . unwrap( ) ; 
350-                                 streams. push( stream) ; 
351-                             } 
352-                             _ = notify. notified( )  => { 
353-                                 break ; 
354-                             } 
355-                         } 
356-                     } 
357- 
358-                     Ok ( ( ) ) 
359-                 } 
360-             } 
361-         } ) ; 
362- 
363-         sim. client ( "client" ,  async  move  { 
364-             let  stream_id = StreamId :: new ( 1 ) ; 
365-             let  node_id = NodeId :: new_v4 ( ) ; 
366-             let  s = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ; 
367-             let  mut  s = AsyncBincodeStream :: < _ ,  Message ,  Message ,  _ > :: from ( s) . for_async ( ) ; 
368- 
369-             s. send ( Message :: Node ( NodeMessage :: Handshake  { 
370-                 protocol_version :  CURRENT_PROTO_VERSION , 
371-                 node_id, 
372-             } ) ) 
373-             . await 
374-             . unwrap ( ) ; 
375-             let  m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ; 
376-             assert ! ( matches!( m,  Message :: Node ( NodeMessage :: Handshake  {  .. } ) ) ) ; 
377- 
378-             // send message to unexisting stream: 
379-             s. send ( Message :: Stream  { 
380-                 stream_id, 
381-                 payload :  StreamMessage :: Dummy , 
382-             } ) 
383-             . await 
384-             . unwrap ( ) ; 
385-             let  m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ; 
386-             assert_eq ! ( 
387-                 m, 
388-                 Message :: Node ( NodeMessage :: Error ( NodeError :: UnknownStream ( stream_id) ) ) 
389-             ) ; 
390- 
391-             // open stream then send message 
392-             s. send ( Message :: Node ( NodeMessage :: OpenStream  { 
393-                 stream_id, 
394-                 database_id, 
395-             } ) ) 
396-             . await 
397-             . unwrap ( ) ; 
398-             s. send ( Message :: Stream  { 
399-                 stream_id, 
400-                 payload :  StreamMessage :: Dummy , 
401-             } ) 
402-             . await 
403-             . unwrap ( ) ; 
404-             let  m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ; 
405-             assert_eq ! ( 
406-                 m, 
407-                 Message :: Stream  { 
408-                     stream_id, 
409-                     payload:  StreamMessage :: Dummy 
410-                 } 
411-             ) ; 
412- 
413-             s. send ( Message :: Node ( NodeMessage :: CloseStream  { 
414-                 stream_id :  StreamId :: new ( 1 ) , 
415-             } ) ) 
416-             . await 
417-             . unwrap ( ) ; 
418-             s. send ( Message :: Stream  { 
419-                 stream_id, 
420-                 payload :  StreamMessage :: Dummy , 
421-             } ) 
422-             . await 
423-             . unwrap ( ) ; 
424-             let  m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ; 
425-             assert_eq ! ( 
426-                 m, 
427-                 Message :: Node ( NodeMessage :: Error ( NodeError :: UnknownStream ( stream_id) ) ) 
428-             ) ; 
429- 
430-             notify. notify_waiters ( ) ; 
329+             done. notify_waiters ( ) ; 
431330
432331            Ok ( ( ) ) 
433332        } ) ; 
@@ -459,7 +358,7 @@ mod test {
459358
460359        sim. client ( "client" ,  async  move  { 
461360            let  stream = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ; 
462-             let  bus = Bus :: new ( NodeId :: new_v4 ( ) ) ; 
361+             let  bus = Arc :: new ( Bus :: new ( 1 ,  |_ ,  _|  async   { } ) ) ; 
463362            let  mut  conn = Connection :: new_acceptor ( stream,  bus) ; 
464363
465364            notify. notify_waiters ( ) ; 
@@ -473,57 +372,4 @@ mod test {
473372
474373        sim. run ( ) . unwrap ( ) ; 
475374    } 
476- 
477-     #[ test]  
478-     fn  zero_stream_id ( )  { 
479-         let  mut  sim = turmoil:: Builder :: new ( ) . build ( ) ; 
480- 
481-         let  notify = Arc :: new ( Notify :: new ( ) ) ; 
482-         sim. host ( "host" ,  { 
483-             let  notify = notify. clone ( ) ; 
484-             move  || { 
485-                 let  notify = notify. clone ( ) ; 
486-                 async  move  { 
487-                     let  listener = TcpListener :: bind ( "0.0.0.0:1234" ) . await . unwrap ( ) ; 
488-                     let  ( stream,  _)  = listener. accept ( ) . await . unwrap ( ) ; 
489-                     let  ( connection_messages_sender,  connection_messages)  = mpsc:: channel ( 1 ) ; 
490-                     let  conn = Connection  { 
491-                         peer :  Some ( NodeId :: new_v4 ( ) ) , 
492-                         state :  ConnectionState :: Connected , 
493-                         conn :  AsyncBincodeStream :: from ( stream) . for_async ( ) , 
494-                         streams :  HashMap :: new ( ) , 
495-                         connection_messages, 
496-                         connection_messages_sender, 
497-                         is_initiator :  false , 
498-                         bus :  Bus :: new ( NodeId :: new_v4 ( ) ) , 
499-                         stream_id_allocator :  StreamIdAllocator :: new ( false ) , 
500-                         registration :  None , 
501-                     } ; 
502- 
503-                     conn. run ( ) . await ; 
504- 
505-                     Ok ( ( ) ) 
506-                 } 
507-             } 
508-         } ) ; 
509- 
510-         sim. client ( "client" ,  async  move  { 
511-             let  stream = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ; 
512-             let  mut  stream = AsyncBincodeStream :: < _ ,  Message ,  Message ,  _ > :: from ( stream) . for_async ( ) ; 
513- 
514-             stream
515-                 . send ( Message :: Stream  { 
516-                     stream_id :  StreamId :: new_unchecked ( 0 ) , 
517-                     payload :  StreamMessage :: Dummy , 
518-                 } ) 
519-                 . await 
520-                 . unwrap ( ) ; 
521- 
522-             assert ! ( stream. next( ) . await . is_none( ) ) ; 
523- 
524-             Ok ( ( ) ) 
525-         } ) ; 
526- 
527-         sim. run ( ) . unwrap ( ) ; 
528-     } 
529375} 
0 commit comments