From c48e29e8893da021e4d6a5fb1cfd6e551bc24451 Mon Sep 17 00:00:00 2001 From: Alexander Korolev Date: Thu, 30 Mar 2023 23:01:14 +0200 Subject: [PATCH] added JWT support for userinfo endpoint, custom userinfo claims --- Cargo.toml | 1 + src/client.rs | 141 ++++++++++++++++++++++++++------- src/discovered.rs | 4 +- src/error.rs | 6 ++ src/standard_claims_subject.rs | 2 +- src/userinfo.rs | 3 +- 6 files changed, 124 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8930f80..14d854b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ base64 = "0.13" biscuit = "0.6" thiserror = "1" validator = { version = "0.15", features = ["derive"] } +mime = "0.3" [dependencies.url] version = "2" diff --git a/src/client.rs b/src/client.rs index 8566997..731451c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -21,7 +21,7 @@ use serde_json::Value; use std::marker::PhantomData; use url::{form_urlencoded::Serializer, Url}; -/// OAuth 2.0 client. +/// OpenID Connect 1.0 / OAuth 2.0 client. #[derive(Debug)] pub struct Client

{ /// OAuth provider. @@ -76,7 +76,7 @@ impl Client { /// Constructs a client from an issuer url and client parameters via discovery pub async fn discover( id: String, - secret: String, + secret: impl Into>, redirect: impl Into>, issuer: Url, ) -> Result { @@ -87,7 +87,7 @@ impl Client { pub async fn discover_with_client( http_client: reqwest::Client, id: String, - secret: String, + secret: impl Into>, redirect: impl Into>, issuer: Url, ) -> Result { @@ -100,7 +100,7 @@ impl Client { provider, id, secret, - redirect.into(), + redirect, http_client, Some(jwks), )) @@ -191,15 +191,19 @@ impl Client { Ok(token) } - /// Mutates a Compact::encoded Token to Compact::decoded. Errors are: + /// Mutates a Compact::encoded Token to Compact::decoded. /// - /// - Decode::MissingKid if the keyset has multiple keys but the key id on the token is missing - /// - Decode::MissingKey if the given key id is not in the key set - /// - Decode::EmptySet if the keyset is empty - /// - Jose::WrongKeyType if the alg of the key and the alg in the token header mismatch - /// - Jose::WrongKeyType if the specified key alg isn't a signature algorithm - /// - Jose error if decoding fails - pub fn decode_token(&self, token: &mut IdToken) -> Result<(), Error> { + /// # Errors + /// + /// - [Decode::MissingKid] if the keyset has multiple keys but the key id on the token is missing + /// - [Decode::MissingKey] if the given key id is not in the key set + /// - [Decode::EmptySet] if the keyset is empty + /// - [Jose::WrongKeyType] if the alg of the key and the alg in the token header mismatch + /// - [Jose::WrongKeyType] if the specified key alg isn't a signature algorithm + /// - [Decode::UnsupportedEllipticCurve] if the alg is cryptographic curve + /// - [Decode::UnsupportedOctetKeyPair] if the alg is octet key pair + /// - [Error::Jose] error if decoding fails + pub fn decode_token(&self, token: &mut IdToken) -> Result<(), Error> { // This is an early return if the token is already decoded if let Compact::Decoded { .. } = *token { return Ok(()); @@ -263,18 +267,20 @@ impl Client { } /// Validate a decoded token. If you don't get an error, its valid! Nonce and max_age come from - /// your auth_uri options. Errors are: + /// your auth_uri options. + /// + /// # Errors /// - /// - Jose Error if the Token isn't decoded - /// - Validation::Mismatch::Issuer if the provider issuer and token issuer mismatch - /// - Validation::Mismatch::Nonce if a given nonce and the token nonce mismatch - /// - Validation::Missing::Nonce if args has a nonce and the token does not - /// - Validation::Missing::Audience if the token aud doesn't contain the client id - /// - Validation::Missing::AuthorizedParty if there are multiple audiences and azp is missing - /// - Validation::Mismatch::AuthorizedParty if the azp is not the client_id - /// - Validation::Expired::Expires if the current time is past the expiration time - /// - Validation::Expired::MaxAge is the token is older than the provided max_age - /// - Validation::Missing::Authtime if a max_age was given and the token has no auth time + /// - [Error::Jose] Error if the Token isn't decoded + /// - [Error::Validation]::[Mismatch](crate::error::Validation::Mismatch)::[Issuer](crate::error::Mismatch::Issuer) if the provider issuer and token issuer mismatch + /// - [Error::Validation]::[Mismatch](crate::error::Validation::Mismatch)::[Nonce](crate::error::Mismatch::Nonce) if a given nonce and the token nonce mismatch + /// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[Nonce](crate::error::Missing::Nonce) if args has a nonce and the token does not + /// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[Audience](crate::error::Missing::Audience) if the token aud doesn't contain the client id + /// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[AuthorizedParty](crate::error::Missing::AuthorizedParty) if there are multiple audiences and azp is missing + /// - [Error::Validation]::[Mismatch](crate::error::Validation::Mismatch)::[AuthorizedParty](crate::error::Mismatch::AuthorizedParty) if the azp is not the client_id + /// - [Error::Validation]::[Expired](crate::error::Validation::Expired)::[Expires](crate::error::Expiry::Expires) if the current time is past the expiration time + /// - [Error::Validation]::[Expired](crate::error::Validation::Expired)::[MaxAge](crate::error::Expiry::MaxAge) is the token is older than the provided max_age + /// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[AuthTime](crate::error::Missing::AuthTime) if a max_age was given and the token has no auth time pub fn validate_token<'nonce, 'max_age>( &self, token: &IdToken, @@ -298,7 +304,7 @@ impl Client { /// Get a userinfo json document for a given token at the provider's userinfo endpoint. /// Returns [Standard Claims](https://openid.net/specs/openid-connect-basic-1_0.html#StandardClaims) as [Userinfo] struct. /// - /// Errors are: + /// # Errors /// /// - [ErrorUserinfo::NoUrl] if this provider doesn't have a userinfo endpoint /// - [Error::Insecure] if the userinfo url is not https @@ -315,31 +321,105 @@ impl Client { /// Returns [UserInfo Response](https://openid.net/specs/openid-connect-basic-1_0.html#UserInfoResponse) /// including non-standard claims. The sub (subject) Claim MUST always be returned in the UserInfo Response. /// - /// Errors are: + /// # Errors /// /// - [ErrorUserinfo::NoUrl] if this provider doesn't have a userinfo endpoint /// - [Error::Insecure] if the userinfo url is not https + /// - [Decode::MissingKid] if the keyset has multiple keys but the key id on the token is missing + /// - [Decode::MissingKey] if the given key id is not in the key set + /// - [Decode::EmptySet] if the keyset is empty + /// - [Jose::WrongKeyType] if the alg of the key and the alg in the token header mismatch + /// - [Jose::WrongKeyType] if the specified key alg isn't a signature algorithm /// - [Error::Jose] if the token is not decoded /// - [Error::Http] if something goes wrong getting the document /// - [Error::Json] if the response is not a valid Userinfo document /// - [ErrorUserinfo::MissingSubject] if subject (sub) is missing /// - [ErrorUserinfo::MismatchSubject] if the returned userinfo document and tokens subject mismatch + /// - [ErrorUserinfo::MissingContentType] if content-type header is missing + /// - [ErrorUserinfo::ParseContentType] if content-type header is not parsable + /// - [ErrorUserinfo::WrongContentType] if content-type header is not accepted + /// + /// # Examples + /// + /// ```no_run + /// # use openid::{Bearer, DiscoveredClient, error::StandardClaimsSubjectMissing, StandardClaims, StandardClaimsSubject, Token}; + /// # use serde::{Deserialize, Serialize}; + /// # async fn _main() -> Result<(), Box> { + /// # let bearer: Bearer = serde_json::from_str("{}").unwrap(); + /// # let token = Token::::from(bearer); + /// # let client = DiscoveredClient::discover("client_id".to_string(), "client_secret".to_string(), "http://redirect".to_string(), url::Url::parse("http://issuer".into()).unwrap(),).await?; + ///#[derive(Debug, Deserialize, Serialize)] + ///struct CustomUserinfo(std::collections::HashMap); + /// + ///impl StandardClaimsSubject for CustomUserinfo { + /// fn sub(&self) -> Result<&str, StandardClaimsSubjectMissing> { + /// self.0 + /// .get("sub") + /// .and_then(|x| x.as_str()) + /// .ok_or(StandardClaimsSubjectMissing) + /// } + ///} + /// + ///impl openid::CompactJson for CustomUserinfo {} + /// + ///let custom_userinfo: CustomUserinfo = client.request_userinfo_custom(&token).await?; + /// # Ok(()) } + /// ``` pub async fn request_userinfo_custom(&self, token: &Token) -> Result where - U: StandardClaimsSubject + serde::de::DeserializeOwned, + U: StandardClaimsSubject, { match self.config().userinfo_endpoint { Some(ref url) => { let auth_code = token.bearer.access_token.to_string(); - let info: U = self + let response = self .http_client .get(url.clone()) .bearer_auth(auth_code) .send() .await? - .json() - .await?; + .error_for_status()?; + + let content_type = response + .headers() + .get(&CONTENT_TYPE) + .and_then(|content_type| content_type.to_str().ok()) + .ok_or(ErrorUserinfo::MissingContentType)?; + + let mime_type = match content_type { + "application/json" => mime::APPLICATION_JSON, + content_type => content_type.parse::().map_err(|_| { + ErrorUserinfo::ParseContentType { + content_type: content_type.to_string(), + } + })?, + }; + + let info: U = match (mime_type.type_(), mime_type.subtype().as_str()) { + (mime::APPLICATION, "json") => { + let info_value: Value = response.json().await?; + if info_value.get("error").is_some() { + let oauth2_error: OAuth2Error = serde_json::from_value(info_value)?; + return Err(Error::ClientError(oauth2_error.into())); + } + serde_json::from_value(info_value)? + } + (mime::APPLICATION, "jwt") => { + let jwt = response.text().await?; + let mut jwt_encoded: Compact = Compact::new_encoded(&jwt); + self.decode_token(&mut jwt_encoded)?; + let (_, info) = jwt_encoded.unwrap_decoded(); + info + } + _ => { + return Err(ErrorUserinfo::WrongContentType { + content_type: content_type.to_string(), + body: response.bytes().await?.to_vec(), + } + .into()) + } + }; let claims = token.id_token.as_ref().map(|x| x.payload()).transpose()?; if let Some(claims) = claims { @@ -400,6 +480,9 @@ where /// Returns an authorization endpoint URI to direct the user to. /// + /// This function is used by [Client::auth_url]. + /// In most situations it is non needed to use it directly. + /// /// See [RFC 6749, section 3.1](http://tools.ietf.org/html/rfc6749#section-3.1). /// /// # Examples diff --git a/src/discovered.rs b/src/discovered.rs index 7474c8f..83dae25 100644 --- a/src/discovered.rs +++ b/src/discovered.rs @@ -35,13 +35,13 @@ pub async fn discover(client: &Client, mut issuer: Url) -> Result .map_err(|_| Error::CannotBeABase)? .extend(&[".well-known", "openid-configuration"]); - let resp = client.get(issuer).send().await?; + let resp = client.get(issuer).send().await?.error_for_status()?; resp.json().await.map_err(Error::from) } /// Get the JWK set from the given Url. Errors are either a reqwest error or an Insecure error if /// the url isn't https. pub async fn jwks(client: &Client, url: Url) -> Result, Error> { - let resp = client.get(url).send().await?; + let resp = client.get(url).send().await?.error_for_status()?; resp.json().await.map_err(Error::from) } diff --git a/src/error.rs b/src/error.rs index 67db286..0f84eb5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -247,6 +247,12 @@ pub enum Expiry { pub enum Userinfo { #[error("Config has no userinfo url")] NoUrl, + #[error("The UserInfo Endpoint MUST return a content-type header to indicate which format is being returned")] + MissingContentType, + #[error("Not parsable content type header: {content_type}")] + ParseContentType { content_type: String }, + #[error("Wrong content type header: {content_type}. The following are accepted content types: application/json, application/jwt")] + WrongContentType { content_type: String, body: Vec }, #[error("Token and Userinfo Subjects mismatch: '{expected}', '{actual}'")] MismatchSubject { expected: String, actual: String }, #[error(transparent)] diff --git a/src/standard_claims_subject.rs b/src/standard_claims_subject.rs index 5f87100..4c6889e 100644 --- a/src/standard_claims_subject.rs +++ b/src/standard_claims_subject.rs @@ -1,6 +1,6 @@ use crate::error::StandardClaimsSubjectMissing; -pub trait StandardClaimsSubject { +pub trait StandardClaimsSubject: crate::CompactJson { /// Subject - Identifier for the End-User at the Issuer. /// /// See [Standard Claims](https://openid.net/specs/openid-connect-basic-1_0.html#StandardClaims) diff --git a/src/userinfo.rs b/src/userinfo.rs index 8609d13..57cb96e 100644 --- a/src/userinfo.rs +++ b/src/userinfo.rs @@ -6,7 +6,6 @@ use url::Url; use validator::Validate; /// The userinfo struct contains all possible userinfo fields regardless of scope. [See spec.](https://openid.net/specs/openid-connect-basic-1_0.html#StandardClaims) -// TODO is there a way to use claims_supported in config to simplify this struct? #[derive(Debug, Deserialize, Serialize, Validate, Clone, Eq, PartialEq)] pub struct Userinfo { #[serde(default)] @@ -84,3 +83,5 @@ impl StandardClaimsSubject for Userinfo { .ok_or(crate::error::StandardClaimsSubjectMissing) } } + +impl biscuit::CompactJson for Userinfo {}