Skip to content

Commit f2073b7

Browse files
author
Nathan Blinn
committed
Fix recovery token logic and add test
1 parent e7cd547 commit f2073b7

File tree

4 files changed

+77
-11
lines changed

4 files changed

+77
-11
lines changed

src/client/executor.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,8 @@ impl Client {
411411
self.update_cluster_time(&r, session).await;
412412
if r.is_success() {
413413
// Retrieve recovery token from successful response.
414-
if let Some(ref mut session) = session {
415-
if session.in_transaction() && is_sharded {
416-
session.transaction.recovery_token = r.recovery_token();
417-
}
414+
if is_sharded {
415+
Client::update_recovery_token(&r, session).await;
418416
}
419417

420418
Ok(CommandResult {
@@ -461,7 +459,10 @@ impl Client {
461459
}))
462460
}
463461
// for ok: 1 just return the original deserialization error.
464-
_ => Err(deserialize_error),
462+
_ => {
463+
Client::update_recovery_token(&error_response, session).await;
464+
Err(deserialize_error)
465+
},
465466
}
466467
}
467468
// We failed to deserialize even that, so just return the original
@@ -640,6 +641,17 @@ impl Client {
640641
}
641642
}
642643
}
644+
645+
async fn update_recovery_token<T: Response>(
646+
response: &T,
647+
session: &mut Option<&mut ClientSession>,
648+
) {
649+
if let Some(ref mut session) = session {
650+
if session.in_transaction() {
651+
session.transaction.recovery_token = response.recovery_token().cloned();
652+
}
653+
}
654+
}
643655
}
644656

645657
impl Error {

src/operation/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ pub(crate) trait Response: Sized {
154154
fn at_cluster_time(&self) -> Option<Timestamp>;
155155

156156
/// The `recoveryToken` field of the response.
157-
fn recovery_token(&self) -> Option<Document>;
157+
fn recovery_token(&self) -> Option<&Document>;
158158

159159
/// Convert into the body of the response.
160160
fn into_body(self) -> Self::Body;
@@ -202,8 +202,8 @@ impl<T: DeserializeOwned> Response for CommandResponse<T> {
202202
self.at_cluster_time
203203
}
204204

205-
fn recovery_token(&self) -> Option<Document> {
206-
self.recovery_token.clone()
205+
fn recovery_token(&self) -> Option<&Document> {
206+
self.recovery_token.as_ref()
207207
}
208208

209209
fn into_body(self) -> Self::Body {
@@ -238,7 +238,7 @@ impl<T: DeserializeOwned> Response for CursorResponse<T> {
238238
self.response.body.cursor.at_cluster_time
239239
}
240240

241-
fn recovery_token(&self) -> Option<Document> {
241+
fn recovery_token(&self) -> Option<&Document> {
242242
self.response.recovery_token()
243243
}
244244

src/operation/run_command/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ impl super::Response for Response {
138138
.ok()
139139
}
140140

141-
fn recovery_token(&self) -> Option<Document> {
142-
self.recovery_token.clone()
141+
fn recovery_token(&self) -> Option<&Document> {
142+
self.recovery_token.as_ref()
143143
}
144144

145145
fn into_body(self) -> Self::Body {

src/test/spec/transactions.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
bson::{doc, serde_helpers::serialize_u64_as_i32, Document},
66
client::session::TransactionState,
77
test::{run_spec_test, TestClient, LOCK},
8+
Collection,
89
};
910

1011
use super::{run_unified_format_test, run_v2_test};
@@ -92,3 +93,56 @@ async fn client_errors() {
9293
assert!(result.is_err());
9394
assert_eq!(session.transaction.state, TransactionState::InProgress);
9495
}
96+
97+
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
98+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
99+
#[function_name::named]
100+
async fn deserialize_recovery_token() {
101+
let _guard: RwLockReadGuard<()> = LOCK.run_concurrently().await;
102+
103+
#[derive(Debug, Serialize)]
104+
struct A {
105+
num: i32,
106+
}
107+
108+
#[derive(Debug, Deserialize)]
109+
struct B {
110+
str: String,
111+
}
112+
113+
let client = TestClient::new().await;
114+
if !client.is_sharded() || client.server_version_lt(4, 2) {
115+
return;
116+
}
117+
118+
let mut session = client.start_session(None).await.unwrap();
119+
120+
// Insert a document with schema A.
121+
client
122+
.database(function_name!())
123+
.collection::<Document>(function_name!())
124+
.drop(None)
125+
.await
126+
.unwrap();
127+
client
128+
.database(function_name!())
129+
.create_collection(function_name!(), None)
130+
.await
131+
.unwrap();
132+
let coll = client
133+
.database(function_name!())
134+
.collection(function_name!());
135+
coll.insert_one(A { num: 4 }, None).await.unwrap();
136+
137+
// Attempt to execute Find on a document with schema B.
138+
let coll: Collection<B> = client
139+
.database(function_name!())
140+
.collection(function_name!());
141+
session.start_transaction(None).await.unwrap();
142+
assert!(session.transaction.recovery_token.is_none());
143+
let result = coll.find_one_with_session(None, None, &mut session).await;
144+
assert!(result.is_err()); // Assert that the deserialization failed.
145+
146+
// Nevertheless, the recovery token should have been retrieved from the ok: 1 response.
147+
assert!(session.transaction.recovery_token.is_some());
148+
}

0 commit comments

Comments
 (0)