Skip to content

RUST-1894 Retry KMS requests on transient errors #1281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[profile.default]
test-threads = 1
default-filter = 'not test(test::happy_eyeballs)'
default-filter = 'not test(test::happy_eyeballs) and not test(kms_retry)'
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

POC for skipping tests locally that need additional setup (in this case, a mock server running). We don't need to include this in the CI filter because the tests in csfle are only run by the CSFLE tasks, in which case the mock server needed here has been started by start-servers.sh.


[profile.ci]
failure-output = "final"
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ time = "0.3.9"
tokio = { version = ">= 0.0.0", features = ["fs", "parking_lot"] }
tracing-subscriber = "0.3.16"
regex = "1.6.0"
reqwest = { version = "0.12.2", features = ["rustls-tls"] }
serde-hex = "0.1.0"
serde_path_to_error = "0.1"

Expand Down
1 change: 1 addition & 0 deletions src/client/csfle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl ClientState {
let mut builder = Crypt::builder()
.kms_providers(&opts.kms_providers.credentials_doc()?)?
.use_need_kms_credentials_state()
.retry_kms(true)?
.use_range_v2()?;
if let Some(m) = &opts.schema_map {
builder = builder.schema_map(&bson::to_document(m)?)?;
Expand Down
1 change: 1 addition & 0 deletions src/client/csfle/client_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ impl ClientEncryption {
let crypt = Crypt::builder()
.kms_providers(&kms_providers.credentials_doc()?)?
.use_need_kms_credentials_state()
.retry_kms(true)?
.use_range_v2()?
.build()?;
let exec = CryptExecutor::new_explicit(
Expand Down
90 changes: 58 additions & 32 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ use std::{
convert::TryInto,
ops::DerefMut,
path::{Path, PathBuf},
time::Duration,
};

use bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
use futures_util::{stream, TryStreamExt};
use mongocrypt::ctx::{Ctx, KmsProviderType, State};
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
use rayon::ThreadPool;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{oneshot, Mutex},
};

use crate::{
client::{options::ServerAddress, WeakClient},
client::{csfle::options::KmsProvidersTlsOptions, options::ServerAddress, WeakClient},
error::{Error, Result},
operation::{run_command::RunCommand, RawOutput},
options::ReadConcern,
Expand Down Expand Up @@ -174,37 +175,62 @@ impl CryptExecutor {
State::NeedKms => {
let ctx = result_mut(&mut ctx)?;
let scope = ctx.kms_scope();
let mut kms_ctxen: Vec<Result<_>> = vec![];
while let Some(kms_ctx) = scope.next_kms_ctx() {
kms_ctxen.push(Ok(kms_ctx));

async fn execute(
kms_ctx: &mut KmsCtx<'_>,
tls_options: Option<&KmsProvidersTlsOptions>,
) -> Result<()> {
let endpoint = kms_ctx.endpoint()?;
let addr = ServerAddress::parse(endpoint)?;
let provider = kms_ctx.kms_provider()?;
let tls_options = tls_options
.and_then(|tls| tls.get(&provider))
.cloned()
.unwrap_or_default();
let mut stream =
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?)).await?;
stream.write_all(kms_ctx.message()?).await?;
let mut buf = vec![0];
while kms_ctx.bytes_needed() > 0 {
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
Error::internal(format!("buffer size overflow: {}", e))
})?;
buf.resize(buf_size, 0);
let count = stream.read(&mut buf).await?;
kms_ctx.feed(&buf[0..count])?;
}
Ok(())
}

loop {
let mut kms_contexts: Vec<Result<_>> = Vec::new();
while let Some(kms_ctx) = scope.next_kms_ctx() {
kms_contexts.push(Ok(kms_ctx));
}
if kms_contexts.is_empty() {
break;
}

stream::iter(kms_contexts)
.try_for_each_concurrent(None, |mut kms_ctx| async move {
let sleep_micros =
u64::try_from(kms_ctx.sleep_micros()).unwrap_or(0);
if sleep_micros > 0 {
tokio::time::sleep(Duration::from_micros(sleep_micros)).await;
}

if let Err(error) =
execute(&mut kms_ctx, self.kms_providers.tls_options()).await
{
if !kms_ctx.retry_failure() {
return Err(error);
}
}

Ok(())
})
.await?;
}
stream::iter(kms_ctxen)
.try_for_each_concurrent(None, |mut kms_ctx| async move {
let endpoint = kms_ctx.endpoint()?;
let addr = ServerAddress::parse(endpoint)?;
let provider = kms_ctx.kms_provider()?;
let tls_options = self
.kms_providers
.tls_options()
.and_then(|tls| tls.get(&provider))
.cloned()
.unwrap_or_default();
let mut stream =
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?))
.await?;
stream.write_all(kms_ctx.message()?).await?;
let mut buf = vec![0];
while kms_ctx.bytes_needed() > 0 {
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
Error::internal(format!("buffer size overflow: {}", e))
})?;
buf.resize(buf_size, 0);
let count = stream.read(&mut buf).await?;
kms_ctx.feed(&buf[0..count])?;
}
Ok(())
})
.await?;
}
State::NeedKmsCredentials => {
let ctx = result_mut(&mut ctx)?;
Expand Down
161 changes: 161 additions & 0 deletions src/test/csfle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3493,6 +3493,167 @@ async fn range_explicit_encryption_defaults() -> Result<()> {
Ok(())
}

// Prose Test 24. KMS Retry Tests
#[tokio::test]
// using openssl causes errors after configuring a network failpoint
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was getting opaque SSL errors from the Python server after configuring network errors with openssl-tls enabled - we're not gaining any useful coverage by running this test on rustls and openssl, so it didn't seem worth it to try to make it work

#[cfg(not(feature = "openssl-tls"))]
async fn kms_retry() {
use reqwest::{Certificate, Client as HttpClient};

let endpoint = "127.0.0.1:9003";

let mut certificate_file_path = PathBuf::from(std::env::var("CSFLE_TLS_CERT_DIR").unwrap());
certificate_file_path.push("ca.pem");
let certificate_file = std::fs::read(&certificate_file_path).unwrap();

let set_failpoint = |kind: &str, count: u8| {
// create a fresh client for each request to avoid hangs
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some more odd behavior: when I reused a single HTTP client for setting failpoints, this line would hang for 90 seconds and then connect successfully. I saw a few issues filed in reqwest's repo describing similar behavior that fixed it by re-creating the client for each request, so I just did the same here.

let http_client = HttpClient::builder()
.add_root_certificate(Certificate::from_pem(&certificate_file).unwrap())
.build()
.unwrap();
let url = format!("https://localhost:9003/set_failpoint/{}", kind);
let body = format!("{{\"count\":{}}}", count);
http_client.post(url).body(body).send()
};

let aws_kms = AWS_KMS.clone();
let mut azure_kms = AZURE_KMS.clone();
azure_kms.1.insert("identityPlatformEndpoint", endpoint);
let mut gcp_kms = GCP_KMS.clone();
gcp_kms.1.insert("endpoint", endpoint);
let mut kms_providers = vec![aws_kms, azure_kms, gcp_kms];

let tls_options = get_client_options().await.tls_options();
for kms_provider in kms_providers.iter_mut() {
kms_provider.2 = tls_options.clone();
}

let key_vault_client = Client::for_test().await.into_client();
let client_encryption = ClientEncryption::new(
key_vault_client,
Namespace::new("keyvault", "datakeys"),
kms_providers,
)
.unwrap();

let aws_master_key = AwsMasterKey::builder()
.region("foo")
.key("bar")
.endpoint(endpoint.to_string())
.build();
let azure_master_key = AzureMasterKey::builder()
.key_vault_endpoint(endpoint)
.key_name("foo")
.build();
let gcp_master_key = GcpMasterKey::builder()
.project_id("foo")
.location("bar")
.key_ring("baz")
.key_name("qux")
.endpoint(endpoint.to_string())
.build();

// Case 1: createDataKey and encrypt with TCP retry

// AWS
set_failpoint("network", 1).await.unwrap();
let key_id = client_encryption
.create_data_key(aws_master_key.clone())
.await
.unwrap();
set_failpoint("network", 1).await.unwrap();
client_encryption
.encrypt(123, key_id, Algorithm::Deterministic)
.await
.unwrap();

// Azure
set_failpoint("network", 1).await.unwrap();
let key_id = client_encryption
.create_data_key(azure_master_key.clone())
.await
.unwrap();
set_failpoint("network", 1).await.unwrap();
client_encryption
.encrypt(123, key_id, Algorithm::Deterministic)
.await
.unwrap();

// GCP
set_failpoint("network", 1).await.unwrap();
let key_id = client_encryption
.create_data_key(gcp_master_key.clone())
.await
.unwrap();
set_failpoint("network", 1).await.unwrap();
client_encryption
.encrypt(123, key_id, Algorithm::Deterministic)
.await
.unwrap();

// Case 2: createDataKey and encrypt with HTTP retry

// AWS
set_failpoint("http", 1).await.unwrap();
let key_id = client_encryption
.create_data_key(aws_master_key.clone())
.await
.unwrap();
set_failpoint("http", 1).await.unwrap();
client_encryption
.encrypt(123, key_id, Algorithm::Deterministic)
.await
.unwrap();

// Azure
set_failpoint("http", 1).await.unwrap();
let key_id = client_encryption
.create_data_key(azure_master_key.clone())
.await
.unwrap();
set_failpoint("http", 1).await.unwrap();
client_encryption
.encrypt(123, key_id, Algorithm::Deterministic)
.await
.unwrap();

// GCP
set_failpoint("http", 1).await.unwrap();
let key_id = client_encryption
.create_data_key(gcp_master_key.clone())
.await
.unwrap();
set_failpoint("http", 1).await.unwrap();
client_encryption
.encrypt(123, key_id, Algorithm::Deterministic)
.await
.unwrap();

// Case 3: createDataKey fails after too many retries

// AWS
set_failpoint("network", 4).await.unwrap();
client_encryption
.create_data_key(aws_master_key)
.await
.unwrap_err();

// Azure
set_failpoint("network", 4).await.unwrap();
client_encryption
.create_data_key(azure_master_key)
.await
.unwrap_err();

// GCP
set_failpoint("network", 4).await.unwrap();
client_encryption
.create_data_key(gcp_master_key)
.await
.unwrap_err();
}

// FLE 2.0 Documentation Example
#[tokio::test]
async fn fle2_example() -> Result<()> {
Expand Down