Skip to content

RUST-1442 On-demand Azure KMS credentials #872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .evergreen/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ functions:
export TOPOLOGY=${TOPOLOGY}
export MONGODB_VERSION=${MONGODB_VERSION}

export AZURE_IMDS_MOCK_PORT=44175

if [ "Windows_NT" != "$OS" ]; then
ulimit -n 64000
fi
Expand Down Expand Up @@ -488,6 +490,16 @@ functions:
export TLS_FEATURE=${TLS_FEATURE}
.evergreen/run-csfle-kmip-servers.sh

"run mock azure imds server":
- command: shell.exec
params:
shell: bash
working_dir: "src"
background: true
script: |
${PREPARE_SHELL}
.evergreen/run-csfle-mock-azure-imds.sh

"build csfle expansions":
- command: shell.exec
params:
Expand Down Expand Up @@ -1214,6 +1226,7 @@ tasks:
- func: "install junit dependencies"
- func: "bootstrap mongo-orchestration"
- func: "run kmip server"
- func: "run mock azure imds server"
- func: "build csfle expansions"
- func: "run csfle tests"

Expand All @@ -1229,6 +1242,7 @@ tasks:
- func: "install junit dependencies"
- func: "install libmongocrypt"
- func: "run kmip server"
- func: "run mock azure imds server"
- func: "build csfle expansions"
- func: "run csfle serverless tests"

Expand Down
2 changes: 1 addition & 1 deletion .evergreen/feature-combinations.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export NO_FEATURES=''
# async-std-related features that conflict with the library's default features.
export ASYNC_STD_FEATURES='--no-default-features --features async-std-runtime,sync'
# All additional features that do not conflict with the default features. New features added to the library should also be added to this list.
export ADDITIONAL_FEATURES='--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,tracing-unstable,in-use-encryption-unstable'
export ADDITIONAL_FEATURES='--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,tracing-unstable,in-use-encryption-unstable,azure-kms'


# Array of feature combinations that, in total, provides complete coverage of the driver.
Expand Down
9 changes: 9 additions & 0 deletions .evergreen/run-csfle-mock-azure-imds.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash

. ${DRIVERS_TOOLS}/.evergreen/find-python3.sh
PYTHON=$(find_python3)

function prepend() { while read line; do echo "${1}${line}"; done; }

cd ${DRIVERS_TOOLS}/.evergreen/csfle
${PYTHON} bottle.py fake_azure:imds -b localhost:${AZURE_IMDS_MOCK_PORT} 2>&1 | prepend "[MOCK AZURE IMDS] "
2 changes: 1 addition & 1 deletion .evergreen/run-csfle-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ source ./.evergreen/env.sh

set -o xtrace

FEATURE_FLAGS="in-use-encryption-unstable,aws-auth,${TLS_FEATURE}"
FEATURE_FLAGS="in-use-encryption-unstable,aws-auth,azure-kms,${TLS_FEATURE}"
OPTIONS="-- -Z unstable-options --format json --report-time"

if [ "$SINGLE_THREAD" = true ]; then
Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ bson-uuid-1 = ["bson/uuid-1"]
# This can only be used with the tokio-runtime feature flag.
aws-auth = ["reqwest"]

# Enable support for on-demand Azure KMS credentials.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a note here that this can only be used with tokio (similar to above)? It's a bummer that we need to add a new feature flag here; I'll also need to do so for the GCP KMS work. We should try to consolidate these in 3.0.0; we could unify them under the reqwest feature created by the optional dependency, or possibly make them simpler if we remove support for tokio.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, done. And yeah, agreed that it's not a great situation w.r.t. features and dependencies.

# This can only be used with the tokio-runtime feature flag.
azure-kms = ["reqwest"]

zstd-compression = ["zstd"]
zlib-compression = ["flate2"]
snappy-compression = ["snap"]
Expand Down
13 changes: 9 additions & 4 deletions src/client/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use crate::{
client::options::ServerApi,
cmap::{Command, Connection, StreamDescription},
error::{Error, ErrorKind, Result},
runtime::HttpClient,
};

const SCRAM_SHA_1_STR: &str = "SCRAM-SHA-1";
Expand Down Expand Up @@ -253,7 +252,7 @@ impl AuthMechanism {
stream: &mut Connection,
credential: &Credential,
server_api: Option<&ServerApi>,
#[cfg_attr(not(feature = "aws-auth"), allow(unused))] http_client: &HttpClient,
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
) -> Result<()> {
self.validate_credential(credential)?;

Expand Down Expand Up @@ -398,9 +397,9 @@ impl Credential {
pub(crate) async fn authenticate_stream(
&self,
conn: &mut Connection,
http_client: &HttpClient,
server_api: Option<&ServerApi>,
first_round: Option<FirstRound>,
#[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient,
) -> Result<()> {
let stream_description = conn.stream_description()?;

Expand Down Expand Up @@ -431,7 +430,13 @@ impl Credential {

// Authenticate according to the chosen mechanism.
mechanism
.authenticate_stream(conn, self, server_api, http_client)
.authenticate_stream(
conn,
self,
server_api,
#[cfg(feature = "aws-auth")]
http_client,
)
.await
}

Expand Down
2 changes: 1 addition & 1 deletion src/client/csfle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub(crate) mod client_builder;
pub mod client_encryption;
pub mod options;
mod state_machine;
pub(crate) mod state_machine;

use std::{path::Path, time::Duration};

Expand Down
157 changes: 153 additions & 4 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub(crate) struct CryptExecutor {
mongocryptd: Option<Mongocryptd>,
mongocryptd_client: Option<Client>,
metadata_client: Option<WeakClient>,
#[cfg(feature = "azure-kms")]
azure: azure::ExecutorState,
}

impl CryptExecutor {
Expand All @@ -56,6 +58,8 @@ impl CryptExecutor {
mongocryptd: None,
mongocryptd_client: None,
metadata_client: None,
#[cfg(feature = "azure-kms")]
azure: azure::ExecutorState::new()?,
})
}

Expand Down Expand Up @@ -211,11 +215,10 @@ impl CryptExecutor {
let ctx = result_mut(&mut ctx)?;
#[allow(unused_mut)]
let mut out = rawdoc! {};
if self
.kms_providers
.credentials()
let credentials = self.kms_providers.credentials();
if credentials
.get(&KmsProvider::Aws)
.map_or(false, |d| d.is_empty())
.map_or(false, Document::is_empty)
{
#[cfg(feature = "aws-auth")]
{
Expand All @@ -240,6 +243,21 @@ impl CryptExecutor {
));
}
}
if credentials
.get(&KmsProvider::Azure)
.map_or(false, Document::is_empty)
{
#[cfg(feature = "azure-kms")]
{
out.append("azure", self.azure.get_token().await?);
}
#[cfg(not(feature = "azure-kms"))]
{
return Err(Error::invalid_argument(
"On-demand Azure KMS credentials require the `azure-kms` feature.",
));
}
}
ctx.provide_kms_providers(&out)?;
}
State::Ready => {
Expand Down Expand Up @@ -346,3 +364,134 @@ fn raw_to_doc(raw: &RawDocument) -> Result<Document> {
raw.try_into()
.map_err(|e| Error::internal(format!("could not parse raw document: {}", e)))
}

#[cfg(feature = "azure-kms")]
pub(crate) mod azure {
use bson::{rawdoc, RawDocumentBuf};
use serde::Deserialize;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

use crate::{
error::{Error, Result},
runtime::HttpClient,
};

#[derive(Debug)]
pub(crate) struct ExecutorState {
cached_access_token: Mutex<Option<CachedAccessToken>>,
http: HttpClient,
#[cfg(test)]
pub(crate) test_host: Option<(&'static str, u16)>,
#[cfg(test)]
pub(crate) test_param: Option<&'static str>,
}

impl ExecutorState {
pub(crate) fn new() -> Result<Self> {
const AZURE_IMDS_TIMEOUT: Duration = Duration::from_secs(10);
Ok(Self {
cached_access_token: Mutex::new(None),
http: HttpClient::with_timeout(AZURE_IMDS_TIMEOUT)?,
#[cfg(test)]
test_host: None,
#[cfg(test)]
test_param: None,
})
}

pub(crate) async fn get_token(&self) -> Result<RawDocumentBuf> {
let mut cached_token = self.cached_access_token.lock().await;
if let Some(cached) = &*cached_token {
if cached.expire_time.saturating_duration_since(Instant::now())
> Duration::from_secs(60)
{
return Ok(cached.token_doc.clone());
}
}
let token = self.fetch_new_token().await?;
let out = token.token_doc.clone();
*cached_token = Some(token);
Ok(out)
}

async fn fetch_new_token(&self) -> Result<CachedAccessToken> {
let now = Instant::now();
let server_response: ServerResponse = self
.http
.get_and_deserialize_json(self.make_url()?, &self.make_headers())
.await
.map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?;
let expires_in_secs: u64 = server_response.expires_in.parse().map_err(|e| {
Error::authentication_error(
"azure imds",
&format!("invalid `expires_in` response field: {}", e),
)
})?;
#[allow(clippy::redundant_clone)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is a redundant clone needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The #[cfg(test)] line below needs server_response, but without the clone this line is a partial move.

Ok(CachedAccessToken {
token_doc: rawdoc! { "accessToken": server_response.access_token.clone() },
expire_time: now + Duration::from_secs(expires_in_secs),
#[cfg(test)]
server_response,
})
}

fn make_url(&self) -> Result<reqwest::Url> {
let url = reqwest::Url::parse_with_params(
"http://169.254.169.254/metadata/identity/oauth2/token",
&[
("api-version", "2018-02-01"),
("resource", "https://vault.azure.net"),
],
)
.map_err(|e| Error::internal(format!("invalid Azure IMDS URL: {}", e)))?;
#[cfg(test)]
let url = {
let mut url = url;
if let Some((host, port)) = self.test_host {
url.set_host(Some(host))
.map_err(|e| Error::internal(format!("invalid test host: {}", e)))?;
url.set_port(Some(port))
.map_err(|()| Error::internal(format!("invalid test port {}", port)))?;
}
url
};
Ok(url)
}

fn make_headers(&self) -> Vec<(&'static str, &'static str)> {
let headers = vec![("Metadata", "true"), ("Accept", "application/json")];
#[cfg(test)]
let headers = {
let mut headers = headers;
if let Some(p) = self.test_param {
headers.push(("X-MongoDB-HTTP-TestParams", p));
}
headers
};
headers
}

#[cfg(test)]
pub(crate) async fn take_cached(&self) -> Option<CachedAccessToken> {
self.cached_access_token.lock().await.take()
}
}

#[derive(Debug, Deserialize)]
pub(crate) struct ServerResponse {
pub(crate) access_token: String,
pub(crate) expires_in: String,
#[allow(unused)]
pub(crate) resource: String,
}

#[derive(Debug)]
pub(crate) struct CachedAccessToken {
pub(crate) token_doc: RawDocumentBuf,
pub(crate) expire_time: Instant,
#[cfg(test)]
pub(crate) server_response: ServerResponse,
}
}
14 changes: 8 additions & 6 deletions src/cmap/establish/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::{
error::Result,
hello::{hello_command, run_hello, HelloReply},
options::{AuthMechanism, Credential, DriverInfo, ServerApi},
runtime::HttpClient,
};

#[cfg(all(feature = "tokio-runtime", not(feature = "tokio-sync")))]
Expand Down Expand Up @@ -323,16 +322,17 @@ pub(crate) struct Handshaker {
#[allow(dead_code)]
compressors: Option<Vec<Compressor>>,

http_client: HttpClient,

server_api: Option<ServerApi>,

metadata: ClientMetadata,

#[cfg(feature = "aws-auth")]
http_client: crate::runtime::HttpClient,
}

impl Handshaker {
/// Creates a new Handshaker.
pub(crate) fn new(http_client: HttpClient, options: HandshakerOptions) -> Self {
pub(crate) fn new(options: HandshakerOptions) -> Self {
let mut metadata = BASE_CLIENT_METADATA.clone();
let compressors = options.compressors;

Expand Down Expand Up @@ -383,11 +383,12 @@ impl Handshaker {
command.body.insert("client", metadata.clone());

Self {
http_client,
command,
compressors,
server_api: options.server_api,
metadata,
#[cfg(feature = "aws-auth")]
http_client: crate::runtime::HttpClient::default(),
}
}

Expand Down Expand Up @@ -457,9 +458,10 @@ impl Handshaker {
credential
.authenticate_stream(
conn,
&self.http_client,
self.server_api.as_ref(),
first_round,
#[cfg(feature = "aws-auth")]
&self.http_client,
)
.await?
}
Expand Down
Loading