Skip to content

Commit 9643314

Browse files
committed
wip(common): fix
1 parent 6f362ca commit 9643314

File tree

1 file changed

+83
-8
lines changed

1 file changed

+83
-8
lines changed

mithril-common/src/chain_reader/pallas_chain_reader.rs

+83-8
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ impl PallasChainReader {
5757
.with_context(|| "PallasChainReader failed to get a client")
5858
}
5959

60+
/// Check if the client already exists.
61+
fn has_client(&self) -> bool {
62+
self.client.is_some()
63+
}
64+
65+
/// Drops the client by aborting the connection and setting it to `None`.
66+
fn drop_client(&mut self) {
67+
if let Some(client) = self.client.take() {
68+
tokio::spawn(async move {
69+
let _ = client.abort().await;
70+
});
71+
}
72+
self.client = None;
73+
}
74+
6075
/// Intersects the point of the chain with the given point.
6176
async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
6277
let logger = self.logger.clone();
@@ -99,11 +114,7 @@ impl PallasChainReader {
99114

100115
impl Drop for PallasChainReader {
101116
fn drop(&mut self) {
102-
if let Some(client) = self.client.take() {
103-
tokio::spawn(async move {
104-
let _ = client.abort().await;
105-
});
106-
}
117+
self.drop_client();
107118
}
108119
}
109120

@@ -118,11 +129,18 @@ impl ChainBlockReader for PallasChainReader {
118129
let chainsync = client.chainsync();
119130

120131
let next = match chainsync.has_agency() {
121-
true => chainsync.request_next().await?,
122-
false => chainsync.recv_while_must_reply().await?,
132+
true => chainsync.request_next().await,
133+
false => chainsync.recv_while_must_reply().await,
123134
};
124135

125-
self.process_chain_block_next_action(next).await
136+
match next {
137+
Ok(next) => self.process_chain_block_next_action(next).await,
138+
Err(err) => {
139+
self.drop_client();
140+
141+
return Err(err.into());
142+
}
143+
}
126144
}
127145
}
128146

@@ -375,4 +393,61 @@ mod tests {
375393
_ => panic!("Unexpected chain block action"),
376394
}
377395
}
396+
397+
/* #[tokio::test]
398+
async fn get_client_is_dropped_when_error_returned_from_server() {
399+
let socket_path = create_temp_dir("get_client_is_dropped_when_error_returned_from_server")
400+
.join("node.socket");
401+
let known_point = get_fake_specific_point();
402+
let known_point_clone = known_point.clone();
403+
let server = setup_server(
404+
socket_path.clone(),
405+
ServerAction::RollForward,
406+
HasAgency::No,
407+
)
408+
.await;
409+
let socket_path_clone = socket_path.clone();
410+
let client = tokio::spawn(async move {
411+
let mut chain_reader = PallasChainReader::new(
412+
socket_path_clone.as_path(),
413+
CardanoNetwork::TestNet(10),
414+
TestLogger::stdout(),
415+
);
416+
417+
chain_reader
418+
.set_chain_point(&RawCardanoPoint::from(known_point_clone))
419+
.await
420+
.unwrap();
421+
422+
// forces the client to change the chainsync server agency state
423+
let client = chain_reader.get_client().await.unwrap();
424+
client.chainsync().request_next().await.unwrap();
425+
426+
chain_reader.get_next_chain_block().await.unwrap().unwrap();
427+
428+
chain_reader
429+
});
430+
431+
let (_, client_res) = tokio::join!(server, client);
432+
let chain_reader = client_res.expect("Client failed to return chain reader");
433+
let server = setup_server(
434+
socket_path.clone(),
435+
ServerAction::RollForward,
436+
HasAgency::No,
437+
)
438+
.await;
439+
let client = tokio::spawn(async move {
440+
let mut chain_reader = chain_reader;
441+
442+
println!("RES 123");
443+
let res = chain_reader.get_next_chain_block().await;
444+
println!("RES 456");
445+
446+
res
447+
});
448+
let (_, client_res) = tokio::join!(server, client);
449+
let chain_block = client_res
450+
.unwrap()
451+
.expect_err("Server should have returned an error");
452+
} */
378453
}

0 commit comments

Comments
 (0)