@@ -57,6 +57,21 @@ impl PallasChainReader {
57
57
. with_context ( || "PallasChainReader failed to get a client" )
58
58
}
59
59
60
+ /// Check if the client already exists.
61
+ fn has_client ( & self ) -> bool {
62
+ self . client . is_some ( )
63
+ }
64
+
65
+ /// Drops the client by aborting the connection and setting it to `None`.
66
+ fn drop_client ( & mut self ) {
67
+ if let Some ( client) = self . client . take ( ) {
68
+ tokio:: spawn ( async move {
69
+ let _ = client. abort ( ) . await ;
70
+ } ) ;
71
+ }
72
+ self . client = None ;
73
+ }
74
+
60
75
/// Intersects the point of the chain with the given point.
61
76
async fn find_intersect_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
62
77
let logger = self . logger . clone ( ) ;
@@ -99,11 +114,7 @@ impl PallasChainReader {
99
114
100
115
impl Drop for PallasChainReader {
101
116
fn drop ( & mut self ) {
102
- if let Some ( client) = self . client . take ( ) {
103
- tokio:: spawn ( async move {
104
- let _ = client. abort ( ) . await ;
105
- } ) ;
106
- }
117
+ self . drop_client ( ) ;
107
118
}
108
119
}
109
120
@@ -118,11 +129,18 @@ impl ChainBlockReader for PallasChainReader {
118
129
let chainsync = client. chainsync ( ) ;
119
130
120
131
let next = match chainsync. has_agency ( ) {
121
- true => chainsync. request_next ( ) . await ? ,
122
- false => chainsync. recv_while_must_reply ( ) . await ? ,
132
+ true => chainsync. request_next ( ) . await ,
133
+ false => chainsync. recv_while_must_reply ( ) . await ,
123
134
} ;
124
135
125
- self . process_chain_block_next_action ( next) . await
136
+ match next {
137
+ Ok ( next) => self . process_chain_block_next_action ( next) . await ,
138
+ Err ( err) => {
139
+ self . drop_client ( ) ;
140
+
141
+ return Err ( err. into ( ) ) ;
142
+ }
143
+ }
126
144
}
127
145
}
128
146
@@ -375,4 +393,61 @@ mod tests {
375
393
_ => panic ! ( "Unexpected chain block action" ) ,
376
394
}
377
395
}
396
+
397
+ /* #[tokio::test]
398
+ async fn get_client_is_dropped_when_error_returned_from_server() {
399
+ let socket_path = create_temp_dir("get_client_is_dropped_when_error_returned_from_server")
400
+ .join("node.socket");
401
+ let known_point = get_fake_specific_point();
402
+ let known_point_clone = known_point.clone();
403
+ let server = setup_server(
404
+ socket_path.clone(),
405
+ ServerAction::RollForward,
406
+ HasAgency::No,
407
+ )
408
+ .await;
409
+ let socket_path_clone = socket_path.clone();
410
+ let client = tokio::spawn(async move {
411
+ let mut chain_reader = PallasChainReader::new(
412
+ socket_path_clone.as_path(),
413
+ CardanoNetwork::TestNet(10),
414
+ TestLogger::stdout(),
415
+ );
416
+
417
+ chain_reader
418
+ .set_chain_point(&RawCardanoPoint::from(known_point_clone))
419
+ .await
420
+ .unwrap();
421
+
422
+ // forces the client to change the chainsync server agency state
423
+ let client = chain_reader.get_client().await.unwrap();
424
+ client.chainsync().request_next().await.unwrap();
425
+
426
+ chain_reader.get_next_chain_block().await.unwrap().unwrap();
427
+
428
+ chain_reader
429
+ });
430
+
431
+ let (_, client_res) = tokio::join!(server, client);
432
+ let chain_reader = client_res.expect("Client failed to return chain reader");
433
+ let server = setup_server(
434
+ socket_path.clone(),
435
+ ServerAction::RollForward,
436
+ HasAgency::No,
437
+ )
438
+ .await;
439
+ let client = tokio::spawn(async move {
440
+ let mut chain_reader = chain_reader;
441
+
442
+ println!("RES 123");
443
+ let res = chain_reader.get_next_chain_block().await;
444
+ println!("RES 456");
445
+
446
+ res
447
+ });
448
+ let (_, client_res) = tokio::join!(server, client);
449
+ let chain_block = client_res
450
+ .unwrap()
451
+ .expect_err("Server should have returned an error");
452
+ } */
378
453
}
0 commit comments