Skip to content

Commit

Permalink
added JWT support for userinfo endpoint, custom userinfo claims
Browse files Browse the repository at this point in the history
  • Loading branch information
kilork committed Mar 30, 2023
1 parent 0eeec4c commit c48e29e
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ base64 = "0.13"
biscuit = "0.6"
thiserror = "1"
validator = { version = "0.15", features = ["derive"] }
mime = "0.3"

[dependencies.url]
version = "2"
Expand Down
141 changes: 112 additions & 29 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use serde_json::Value;
use std::marker::PhantomData;
use url::{form_urlencoded::Serializer, Url};

/// OAuth 2.0 client.
/// OpenID Connect 1.0 / OAuth 2.0 client.
#[derive(Debug)]
pub struct Client<P = Discovered, C: CompactJson + Claims = StandardClaims> {
/// OAuth provider.
Expand Down Expand Up @@ -76,7 +76,7 @@ impl<C: CompactJson + Claims> Client<Discovered, C> {
/// Constructs a client from an issuer url and client parameters via discovery
pub async fn discover(
id: String,
secret: String,
secret: impl Into<Option<String>>,
redirect: impl Into<Option<String>>,
issuer: Url,
) -> Result<Self, Error> {
Expand All @@ -87,7 +87,7 @@ impl<C: CompactJson + Claims> Client<Discovered, C> {
pub async fn discover_with_client(
http_client: reqwest::Client,
id: String,
secret: String,
secret: impl Into<Option<String>>,
redirect: impl Into<Option<String>>,
issuer: Url,
) -> Result<Self, Error> {
Expand All @@ -100,7 +100,7 @@ impl<C: CompactJson + Claims> Client<Discovered, C> {
provider,
id,
secret,
redirect.into(),
redirect,
http_client,
Some(jwks),
))
Expand Down Expand Up @@ -191,15 +191,19 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
Ok(token)
}

/// Mutates a Compact::encoded Token to Compact::decoded. Errors are:
/// Mutates a Compact::encoded Token to Compact::decoded.
///
/// - Decode::MissingKid if the keyset has multiple keys but the key id on the token is missing
/// - Decode::MissingKey if the given key id is not in the key set
/// - Decode::EmptySet if the keyset is empty
/// - Jose::WrongKeyType if the alg of the key and the alg in the token header mismatch
/// - Jose::WrongKeyType if the specified key alg isn't a signature algorithm
/// - Jose error if decoding fails
pub fn decode_token(&self, token: &mut IdToken<C>) -> Result<(), Error> {
/// # Errors
///
/// - [Decode::MissingKid] if the keyset has multiple keys but the key id on the token is missing
/// - [Decode::MissingKey] if the given key id is not in the key set
/// - [Decode::EmptySet] if the keyset is empty
/// - [Jose::WrongKeyType] if the alg of the key and the alg in the token header mismatch
/// - [Jose::WrongKeyType] if the specified key alg isn't a signature algorithm
/// - [Decode::UnsupportedEllipticCurve] if the alg is cryptographic curve
/// - [Decode::UnsupportedOctetKeyPair] if the alg is octet key pair
/// - [Error::Jose] error if decoding fails
pub fn decode_token<T: CompactJson>(&self, token: &mut IdToken<T>) -> Result<(), Error> {
// This is an early return if the token is already decoded
if let Compact::Decoded { .. } = *token {
return Ok(());
Expand Down Expand Up @@ -263,18 +267,20 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
}

/// Validate a decoded token. If you don't get an error, its valid! Nonce and max_age come from
/// your auth_uri options. Errors are:
/// your auth_uri options.
///
/// # Errors
///
/// - Jose Error if the Token isn't decoded
/// - Validation::Mismatch::Issuer if the provider issuer and token issuer mismatch
/// - Validation::Mismatch::Nonce if a given nonce and the token nonce mismatch
/// - Validation::Missing::Nonce if args has a nonce and the token does not
/// - Validation::Missing::Audience if the token aud doesn't contain the client id
/// - Validation::Missing::AuthorizedParty if there are multiple audiences and azp is missing
/// - Validation::Mismatch::AuthorizedParty if the azp is not the client_id
/// - Validation::Expired::Expires if the current time is past the expiration time
/// - Validation::Expired::MaxAge is the token is older than the provided max_age
/// - Validation::Missing::Authtime if a max_age was given and the token has no auth time
/// - [Error::Jose] Error if the Token isn't decoded
/// - [Error::Validation]::[Mismatch](crate::error::Validation::Mismatch)::[Issuer](crate::error::Mismatch::Issuer) if the provider issuer and token issuer mismatch
/// - [Error::Validation]::[Mismatch](crate::error::Validation::Mismatch)::[Nonce](crate::error::Mismatch::Nonce) if a given nonce and the token nonce mismatch
/// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[Nonce](crate::error::Missing::Nonce) if args has a nonce and the token does not
/// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[Audience](crate::error::Missing::Audience) if the token aud doesn't contain the client id
/// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[AuthorizedParty](crate::error::Missing::AuthorizedParty) if there are multiple audiences and azp is missing
/// - [Error::Validation]::[Mismatch](crate::error::Validation::Mismatch)::[AuthorizedParty](crate::error::Mismatch::AuthorizedParty) if the azp is not the client_id
/// - [Error::Validation]::[Expired](crate::error::Validation::Expired)::[Expires](crate::error::Expiry::Expires) if the current time is past the expiration time
/// - [Error::Validation]::[Expired](crate::error::Validation::Expired)::[MaxAge](crate::error::Expiry::MaxAge) is the token is older than the provided max_age
/// - [Error::Validation]::[Missing](crate::error::Validation::Missing)::[AuthTime](crate::error::Missing::AuthTime) if a max_age was given and the token has no auth time
pub fn validate_token<'nonce, 'max_age>(
&self,
token: &IdToken<C>,
Expand All @@ -298,7 +304,7 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
/// Get a userinfo json document for a given token at the provider's userinfo endpoint.
/// Returns [Standard Claims](https://openid.net/specs/openid-connect-basic-1_0.html#StandardClaims) as [Userinfo] struct.
///
/// Errors are:
/// # Errors
///
/// - [ErrorUserinfo::NoUrl] if this provider doesn't have a userinfo endpoint
/// - [Error::Insecure] if the userinfo url is not https
Expand All @@ -315,31 +321,105 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
/// Returns [UserInfo Response](https://openid.net/specs/openid-connect-basic-1_0.html#UserInfoResponse)
/// including non-standard claims. The sub (subject) Claim MUST always be returned in the UserInfo Response.
///
/// Errors are:
/// # Errors
///
/// - [ErrorUserinfo::NoUrl] if this provider doesn't have a userinfo endpoint
/// - [Error::Insecure] if the userinfo url is not https
/// - [Decode::MissingKid] if the keyset has multiple keys but the key id on the token is missing
/// - [Decode::MissingKey] if the given key id is not in the key set
/// - [Decode::EmptySet] if the keyset is empty
/// - [Jose::WrongKeyType] if the alg of the key and the alg in the token header mismatch
/// - [Jose::WrongKeyType] if the specified key alg isn't a signature algorithm
/// - [Error::Jose] if the token is not decoded
/// - [Error::Http] if something goes wrong getting the document
/// - [Error::Json] if the response is not a valid Userinfo document
/// - [ErrorUserinfo::MissingSubject] if subject (sub) is missing
/// - [ErrorUserinfo::MismatchSubject] if the returned userinfo document and tokens subject mismatch
/// - [ErrorUserinfo::MissingContentType] if content-type header is missing
/// - [ErrorUserinfo::ParseContentType] if content-type header is not parsable
/// - [ErrorUserinfo::WrongContentType] if content-type header is not accepted
///
/// # Examples
///
/// ```no_run
/// # use openid::{Bearer, DiscoveredClient, error::StandardClaimsSubjectMissing, StandardClaims, StandardClaimsSubject, Token};
/// # use serde::{Deserialize, Serialize};
/// # async fn _main() -> Result<(), Box<dyn std::error::Error>> {
/// # let bearer: Bearer = serde_json::from_str("{}").unwrap();
/// # let token = Token::<StandardClaims>::from(bearer);
/// # let client = DiscoveredClient::discover("client_id".to_string(), "client_secret".to_string(), "http://redirect".to_string(), url::Url::parse("http://issuer".into()).unwrap(),).await?;
///#[derive(Debug, Deserialize, Serialize)]
///struct CustomUserinfo(std::collections::HashMap<String, serde_json::Value>);
///
///impl StandardClaimsSubject for CustomUserinfo {
/// fn sub(&self) -> Result<&str, StandardClaimsSubjectMissing> {
/// self.0
/// .get("sub")
/// .and_then(|x| x.as_str())
/// .ok_or(StandardClaimsSubjectMissing)
/// }
///}
///
///impl openid::CompactJson for CustomUserinfo {}
///
///let custom_userinfo: CustomUserinfo = client.request_userinfo_custom(&token).await?;
/// # Ok(()) }
/// ```
pub async fn request_userinfo_custom<U>(&self, token: &Token<C>) -> Result<U, Error>
where
U: StandardClaimsSubject + serde::de::DeserializeOwned,
U: StandardClaimsSubject,
{
match self.config().userinfo_endpoint {
Some(ref url) => {
let auth_code = token.bearer.access_token.to_string();

let info: U = self
let response = self
.http_client
.get(url.clone())
.bearer_auth(auth_code)
.send()
.await?
.json()
.await?;
.error_for_status()?;

let content_type = response
.headers()
.get(&CONTENT_TYPE)
.and_then(|content_type| content_type.to_str().ok())
.ok_or(ErrorUserinfo::MissingContentType)?;

let mime_type = match content_type {
"application/json" => mime::APPLICATION_JSON,
content_type => content_type.parse::<mime::Mime>().map_err(|_| {
ErrorUserinfo::ParseContentType {
content_type: content_type.to_string(),
}
})?,
};

let info: U = match (mime_type.type_(), mime_type.subtype().as_str()) {
(mime::APPLICATION, "json") => {
let info_value: Value = response.json().await?;
if info_value.get("error").is_some() {
let oauth2_error: OAuth2Error = serde_json::from_value(info_value)?;
return Err(Error::ClientError(oauth2_error.into()));
}
serde_json::from_value(info_value)?
}
(mime::APPLICATION, "jwt") => {
let jwt = response.text().await?;
let mut jwt_encoded: Compact<U, Empty> = Compact::new_encoded(&jwt);
self.decode_token(&mut jwt_encoded)?;
let (_, info) = jwt_encoded.unwrap_decoded();
info
}
_ => {
return Err(ErrorUserinfo::WrongContentType {
content_type: content_type.to_string(),
body: response.bytes().await?.to_vec(),
}
.into())
}
};

let claims = token.id_token.as_ref().map(|x| x.payload()).transpose()?;
if let Some(claims) = claims {
Expand Down Expand Up @@ -400,6 +480,9 @@ where

/// Returns an authorization endpoint URI to direct the user to.
///
/// This function is used by [Client::auth_url].
/// In most situations it is non needed to use it directly.
///
/// See [RFC 6749, section 3.1](http://tools.ietf.org/html/rfc6749#section-3.1).
///
/// # Examples
Expand Down
4 changes: 2 additions & 2 deletions src/discovered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ pub async fn discover(client: &Client, mut issuer: Url) -> Result<Config, Error>
.map_err(|_| Error::CannotBeABase)?
.extend(&[".well-known", "openid-configuration"]);

let resp = client.get(issuer).send().await?;
let resp = client.get(issuer).send().await?.error_for_status()?;
resp.json().await.map_err(Error::from)
}

/// Get the JWK set from the given Url. Errors are either a reqwest error or an Insecure error if
/// the url isn't https.
pub async fn jwks(client: &Client, url: Url) -> Result<JWKSet<Empty>, Error> {
let resp = client.get(url).send().await?;
let resp = client.get(url).send().await?.error_for_status()?;
resp.json().await.map_err(Error::from)
}
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ pub enum Expiry {
pub enum Userinfo {
#[error("Config has no userinfo url")]
NoUrl,
#[error("The UserInfo Endpoint MUST return a content-type header to indicate which format is being returned")]
MissingContentType,
#[error("Not parsable content type header: {content_type}")]
ParseContentType { content_type: String },
#[error("Wrong content type header: {content_type}. The following are accepted content types: application/json, application/jwt")]
WrongContentType { content_type: String, body: Vec<u8> },
#[error("Token and Userinfo Subjects mismatch: '{expected}', '{actual}'")]
MismatchSubject { expected: String, actual: String },
#[error(transparent)]
Expand Down
2 changes: 1 addition & 1 deletion src/standard_claims_subject.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::error::StandardClaimsSubjectMissing;

pub trait StandardClaimsSubject {
pub trait StandardClaimsSubject: crate::CompactJson {
/// Subject - Identifier for the End-User at the Issuer.
///
/// See [Standard Claims](https://openid.net/specs/openid-connect-basic-1_0.html#StandardClaims)
Expand Down
3 changes: 2 additions & 1 deletion src/userinfo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use url::Url;
use validator::Validate;

/// The userinfo struct contains all possible userinfo fields regardless of scope. [See spec.](https://openid.net/specs/openid-connect-basic-1_0.html#StandardClaims)
// TODO is there a way to use claims_supported in config to simplify this struct?
#[derive(Debug, Deserialize, Serialize, Validate, Clone, Eq, PartialEq)]
pub struct Userinfo {
#[serde(default)]
Expand Down Expand Up @@ -84,3 +83,5 @@ impl StandardClaimsSubject for Userinfo {
.ok_or(crate::error::StandardClaimsSubjectMissing)
}
}

impl biscuit::CompactJson for Userinfo {}

0 comments on commit c48e29e

Please sign in to comment.