Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion end-to-end-tests/src/bin/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async fn main() -> Result<()> {
deserialize_byte_stream(
ctx.client
.device_auth_request()
.body(DeviceAuthRequest { client_id })
.body(DeviceAuthRequest { client_id, ttl_seconds: None })
.send()
.await?,
)
Expand Down
12 changes: 11 additions & 1 deletion nexus/db-model/src/device_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ use chrono::{DateTime, Duration, Utc};
use nexus_types::external_api::views;
use omicron_uuid_kinds::{AccessTokenKind, GenericUuid, TypedUuid};
use rand::{Rng, RngCore, SeedableRng, distributions::Slice, rngs::StdRng};
use std::num::NonZeroU32;
use uuid::Uuid;

use crate::SqlU32;
use crate::typed_uuid::DbTypedUuid;

/// Default timeout in seconds for client to authenticate for a token request.
Expand All @@ -32,6 +34,9 @@ pub struct DeviceAuthRequest {
pub user_code: String,
pub time_created: DateTime<Utc>,
pub time_expires: DateTime<Utc>,

/// TTL requested by the user
pub token_ttl_seconds: Option<SqlU32>,
}

impl DeviceAuthRequest {
Expand Down Expand Up @@ -98,7 +103,10 @@ fn generate_user_code() -> String {
}

impl DeviceAuthRequest {
pub fn new(client_id: Uuid) -> Self {
pub fn new(
client_id: Uuid,
requested_ttl_seconds: Option<NonZeroU32>,
) -> Self {
let now = Utc::now();
Self {
client_id,
Expand All @@ -107,6 +115,8 @@ impl DeviceAuthRequest {
time_created: now,
time_expires: now
+ Duration::seconds(CLIENT_AUTHENTICATION_TIMEOUT),
token_ttl_seconds: requested_ttl_seconds
.map(|ttl| ttl.get().into()),
}
}

Expand Down
3 changes: 2 additions & 1 deletion nexus/db-model/src/schema_versions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::{collections::BTreeMap, sync::LazyLock};
///
/// This must be updated when you change the database schema. Refer to
/// schema/crdb/README.adoc in the root of this repository for details.
pub const SCHEMA_VERSION: Version = Version::new(146, 0, 0);
pub const SCHEMA_VERSION: Version = Version::new(147, 0, 0);

/// List of all past database schema versions, in *reverse* order
///
Expand All @@ -28,6 +28,7 @@ static KNOWN_VERSIONS: LazyLock<Vec<KnownVersion>> = LazyLock::new(|| {
// | leaving the first copy as an example for the next person.
// v
// KnownVersion::new(next_int, "unique-dirname-with-the-sql-files"),
KnownVersion::new(147, "device-auth-request-ttl"),
KnownVersion::new(146, "silo-settings-token-expiration"),
KnownVersion::new(145, "token-and-session-ids"),
KnownVersion::new(144, "inventory-omicron-sled-config"),
Expand Down
1 change: 1 addition & 0 deletions nexus/db-schema/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,7 @@ table! {
device_code -> Text,
time_created -> Timestamptz,
time_expires -> Timestamptz,
token_ttl_seconds -> Nullable<Int8>,
}
}

Expand Down
37 changes: 27 additions & 10 deletions nexus/src/app/device_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use nexus_db_queries::context::OpContext;
use nexus_db_queries::db::model::{DeviceAccessToken, DeviceAuthRequest};

use anyhow::anyhow;
use nexus_types::external_api::params::DeviceAccessTokenRequest;
use nexus_types::external_api::params;
use nexus_types::external_api::views;
use omicron_common::api::external::{
CreateResult, DataPageParams, Error, ListResultVec,
Expand All @@ -77,13 +77,17 @@ impl super::Nexus {
pub(crate) async fn device_auth_request_create(
&self,
opctx: &OpContext,
client_id: Uuid,
params: params::DeviceAuthRequest,
) -> CreateResult<DeviceAuthRequest> {
// TODO-correctness: the `user_code` generated for a new request
// is used as a primary key, but may potentially collide with an
// existing outstanding request. So we should retry some (small)
// number of times if inserting the new request fails.
let auth_request = DeviceAuthRequest::new(client_id);

// Note that we cannot validate the TTL here against the silo max
// because we do not know what silo we're talking about until verify
let auth_request =
DeviceAuthRequest::new(params.client_id, params.ttl_seconds);
self.db_datastore.device_auth_request_create(opctx, auth_request).await
}

Expand Down Expand Up @@ -115,17 +119,30 @@ impl super::Nexus {
.silo_auth_settings_view(opctx, &authz_silo)
.await?;

// Create an access token record.
let silo_max_ttl = silo_auth_settings.device_token_max_ttl_seconds;
let requested_ttl = db_request.token_ttl_seconds;

// Validate the requested TTL against the silo's max TTL
if let (Some(requested), Some(max)) = (requested_ttl, silo_max_ttl) {
if requested > max.0.into() {
return Err(Error::invalid_request(&format!(
"Requested TTL {} seconds exceeds maximum \
allowed TTL for this silo of {} seconds",
requested, max
)));
}
}

let time_expires = requested_ttl
.or(silo_max_ttl)
.map(|ttl| Utc::now() + Duration::seconds(ttl.0.into()));

let token = DeviceAccessToken::new(
db_request.client_id,
db_request.device_code,
db_request.time_created,
silo_user_id,
// Token gets the max TTL for the silo (if there is one) until we
// build a way for the user to ask for a different TTL
silo_auth_settings
.device_token_max_ttl_seconds
.map(|ttl| Utc::now() + Duration::seconds(ttl.0.into())),
time_expires,
);

if db_request.time_expires < Utc::now() {
Expand Down Expand Up @@ -224,7 +241,7 @@ impl super::Nexus {
pub(crate) async fn device_access_token(
&self,
opctx: &OpContext,
params: DeviceAccessTokenRequest,
params: params::DeviceAccessTokenRequest,
) -> Result<Response<Body>, HttpError> {
// RFC 8628 §3.4
if params.grant_type != "urn:ietf:params:oauth:grant-type:device_code" {
Expand Down
5 changes: 2 additions & 3 deletions nexus/src/external_api/http_entrypoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7856,9 +7856,8 @@ impl NexusExternalApi for NexusExternalApiImpl {
}
};

let model = nexus
.device_auth_request_create(&opctx, params.client_id)
.await?;
let model =
nexus.device_auth_request_create(&opctx, params).await?;
nexus.build_oauth_response(
StatusCode::OK,
&model.into_response(rqctx.server.using_tls(), host),
Expand Down
186 changes: 183 additions & 3 deletions nexus/tests/integration_tests/device_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
use std::num::NonZeroU32;

use chrono::Utc;
use dropshot::ResultsPage;
use dropshot::test_util::ClientTestContext;
use dropshot::{HttpErrorResponseBody, ResultsPage};
use nexus_auth::authn::USER_TEST_UNPRIVILEGED;
use nexus_db_queries::db::fixed_data::silo::DEFAULT_SILO;
use nexus_db_queries::db::identity::{Asset, Resource};
Expand Down Expand Up @@ -55,7 +55,9 @@ async fn test_device_auth_flow(cptestctx: &ControlPlaneTestContext) {
.expect("failed to reject device auth start without client_id");

let client_id = Uuid::new_v4();
let authn_params = DeviceAuthRequest { client_id };
// note that this exercises ttl_seconds being omitted from the body because
// it's URL encoded, so None means it's omitted
let authn_params = DeviceAuthRequest { client_id, ttl_seconds: None };

// Using a JSON encoded body fails.
RequestBuilder::new(testctx, Method::POST, "/device/auth")
Expand Down Expand Up @@ -241,7 +243,7 @@ async fn test_device_auth_flow(cptestctx: &ControlPlaneTestContext) {
/// as a string
async fn get_device_token(testctx: &ClientTestContext) -> String {
let client_id = Uuid::new_v4();
let authn_params = DeviceAuthRequest { client_id };
let authn_params = DeviceAuthRequest { client_id, ttl_seconds: None };

// Start a device authentication flow
let auth_response: DeviceAuthResponse =
Expand Down Expand Up @@ -431,6 +433,184 @@ async fn test_device_token_expiration(cptestctx: &ControlPlaneTestContext) {
assert_eq!(settings.device_token_max_ttl_seconds, None);
}

// lets me stick whatever I want in this thing to be URL-encoded
#[derive(serde::Serialize)]
struct BadAuthReq {
client_id: String,
ttl_seconds: String,
}

/// Test that 0 and negative values for ttl_seconds give immediate 400s
#[nexus_test]
async fn test_device_token_request_ttl_invalid(
cptestctx: &ControlPlaneTestContext,
) {
let testctx = &cptestctx.external_client;

let auth_response = NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/auth")
.allow_non_dropshot_errors()
.body_urlencoded(Some(&BadAuthReq {
client_id: Uuid::new_v4().to_string(),
ttl_seconds: "0".to_string(),
}))
.expect_status(Some(StatusCode::BAD_REQUEST)),
)
.execute()
// .execute_and_parse_unwrap::<DeviceAuthResponse>()
.await
.expect("expected an Ok(TestResponse)");

let error_body: serde_json::Value =
serde_json::from_slice(&auth_response.body).unwrap();
assert_eq!(
error_body.get("message").unwrap().to_string(),
"\"unable to parse URL-encoded body: ttl_seconds: \
invalid value: integer `0`, expected a nonzero u32\""
);

let auth_response = NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/auth")
.allow_non_dropshot_errors()
.body_urlencoded(Some(&BadAuthReq {
client_id: Uuid::new_v4().to_string(),
ttl_seconds: "-3".to_string(),
}))
.expect_status(Some(StatusCode::BAD_REQUEST)),
)
.execute()
// .execute_and_parse_unwrap::<DeviceAuthResponse>()
.await
.expect("expected an Ok(TestResponse)");

let error_body: serde_json::Value =
serde_json::from_slice(&auth_response.body).unwrap();
assert_eq!(
error_body.get("message").unwrap().to_string(),
"\"unable to parse URL-encoded body: ttl_seconds: \
invalid digit found in string\""
);
}

#[nexus_test]
async fn test_device_token_request_ttl(cptestctx: &ControlPlaneTestContext) {
let testctx = &cptestctx.external_client;

// Set silo max TTL to 10 seconds
let settings = params::SiloAuthSettingsUpdate {
device_token_max_ttl_seconds: NonZeroU32::new(10).into(),
};
let _: views::SiloAuthSettings =
object_put(testctx, "/v1/auth-settings", &settings).await;

// Request TTL above the max should fail at verification time
let invalid_ttl = DeviceAuthRequest {
client_id: Uuid::new_v4(),
ttl_seconds: NonZeroU32::new(20), // Above the 10 second max
};

let auth_response = NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/auth")
.body_urlencoded(Some(&invalid_ttl))
.expect_status(Some(StatusCode::OK)),
)
.execute_and_parse_unwrap::<DeviceAuthResponse>()
.await;

let confirm_params =
DeviceAuthVerify { user_code: auth_response.user_code };

// Confirmation fails because requested TTL exceeds max
let confirm_error = NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/confirm")
.body(Some(&confirm_params))
.expect_status(Some(StatusCode::BAD_REQUEST)),
)
.authn_as(AuthnMode::PrivilegedUser)
.execute_and_parse_unwrap::<HttpErrorResponseBody>()
.await;

// Check that the error message mentions TTL
assert_eq!(confirm_error.error_code, Some("InvalidRequest".to_string()));
assert_eq!(
confirm_error.message,
"Requested TTL 20 seconds exceeds maximum allowed TTL \
for this silo of 10 seconds"
);

// Request TTL below the max should succeed and be used
let valid_ttl = DeviceAuthRequest {
client_id: Uuid::new_v4(),
ttl_seconds: NonZeroU32::new(3), // Below the 10 second max
};

let auth_response = NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/auth")
.body_urlencoded(Some(&valid_ttl))
.expect_status(Some(StatusCode::OK)),
)
.execute_and_parse_unwrap::<DeviceAuthResponse>()
.await;

let device_code = auth_response.device_code;
let user_code = auth_response.user_code;
let confirm_params = DeviceAuthVerify { user_code };

// this time will be pretty close to the now() used on the server when
// calculating expiration time
let t0 = Utc::now();

// Confirmation should succeed
NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/confirm")
.body(Some(&confirm_params))
.expect_status(Some(StatusCode::NO_CONTENT)),
)
.authn_as(AuthnMode::PrivilegedUser)
.execute()
.await
.expect("failed to confirm");

let token_params = DeviceAccessTokenRequest {
grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
device_code,
client_id: valid_ttl.client_id,
};

// Get the token
let token_grant = NexusRequest::new(
RequestBuilder::new(testctx, Method::POST, "/device/token")
.allow_non_dropshot_errors()
.body_urlencoded(Some(&token_params))
.expect_status(Some(StatusCode::OK)),
)
.authn_as(AuthnMode::PrivilegedUser)
.execute_and_parse_unwrap::<DeviceAccessTokenGrant>()
.await;

// Verify the token has roughly the correct expiration time. One second
// threshold is sufficient to confirm it's not getting the silo max of 10
// seconds. Locally, I saw diffs as low as 14ms.
let tokens = get_tokens_priv(testctx).await;
let time_expires = tokens[0].time_expires.unwrap();
let expected_expires = t0 + Duration::from_secs(3);
let diff_ms = (time_expires - expected_expires).num_milliseconds().abs();
assert!(diff_ms <= 1000, "time diff was {diff_ms} ms. should be near zero");

// Token should work initially
project_list(&testctx, &token_grant.access_token, StatusCode::OK)
.await
.expect("token should work initially");

// Wait for token to expire
sleep(Duration::from_secs(4)).await;

// Token is expired
project_list(&testctx, &token_grant.access_token, StatusCode::UNAUTHORIZED)
.await
.expect("token should be expired");
}

async fn get_tokens_priv(
testctx: &ClientTestContext,
) -> Vec<views::DeviceAccessToken> {
Expand Down
3 changes: 3 additions & 0 deletions nexus/types/src/external_api/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2467,6 +2467,9 @@ impl TryFrom<String> for RelativeUri {
#[derive(Clone, Debug, Deserialize, Serialize, JsonSchema)]
pub struct DeviceAuthRequest {
pub client_id: Uuid,
/// Optional lifetime for the access token in seconds. If not specified, the
/// silo's max TTL will be used (if set).
pub ttl_seconds: Option<NonZeroU32>,
}

#[derive(Clone, Debug, Deserialize, Serialize, JsonSchema)]
Expand Down
Loading
Loading