Skip to content

Commit

Permalink
support for public clients
Browse files Browse the repository at this point in the history
  • Loading branch information
kilork committed Mar 15, 2023
1 parent e5563da commit 416e3eb
Show file tree
Hide file tree
Showing 8 changed files with 392 additions and 456 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ keywords = ["authentication", "authorization", "oauth", "openid", "uma2"]
license = "Unlicense OR MIT"
readme = "README.md"
repository = "https://github.com/kilork/openid"
rust-version = "1.65"

[features]
default = ["native-tls"]
Expand Down
147 changes: 85 additions & 62 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
Bearer, Claims, Config, Configurable, Discovered, IdToken, OAuth2Error, Options, Provider,
StandardClaims, Token, Userinfo,
};

use biscuit::{
jwa::{self, SignatureAlgorithm},
jwk::{AlgorithmParameters, JWKSet},
Expand All @@ -29,7 +30,7 @@ pub struct Client<P = Discovered, C: CompactJson + Claims = StandardClaims> {
pub client_id: String,

/// Client secret.
pub client_secret: String,
pub client_secret: Option<String>,

/// Redirect URI.
pub redirect_uri: Option<String>,
Expand Down Expand Up @@ -75,7 +76,7 @@ impl<C: CompactJson + Claims> Client<Discovered, C> {
pub async fn discover(
id: String,
secret: String,
redirect: Option<String>,
redirect: impl Into<Option<String>>,
issuer: Url,
) -> Result<Self, Error> {
Self::discover_with_client(reqwest::Client::new(), id, secret, redirect, issuer).await
Expand All @@ -86,7 +87,7 @@ impl<C: CompactJson + Claims> Client<Discovered, C> {
http_client: reqwest::Client,
id: String,
secret: String,
redirect: Option<String>,
redirect: impl Into<Option<String>>,
issuer: Url,
) -> Result<Self, Error> {
let config = discovered::discover(&http_client, issuer).await?;
Expand All @@ -98,7 +99,7 @@ impl<C: CompactJson + Claims> Client<Discovered, C> {
provider,
id,
secret,
redirect,
redirect.into(),
http_client,
Some(jwks),
))
Expand Down Expand Up @@ -134,7 +135,7 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
None => String::from("openid"),
};

let mut url = self.auth_uri(Some(&scope), options.state.as_deref());
let mut url = self.auth_uri(&*scope, options.state.as_deref());
{
let mut query = url.query_pairs_mut();
if let Some(ref nonce) = options.nonce {
Expand Down Expand Up @@ -177,8 +178,8 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
pub async fn authenticate(
&self,
auth_code: &str,
nonce: Option<&str>,
max_age: Option<&Duration>,
nonce: impl Into<Option<&str>>,
max_age: impl Into<Option<&Duration>>,
) -> Result<Token<C>, Error> {
let bearer = self.request_token(auth_code).await.map_err(Error::from)?;
let mut token: Token<C> = bearer.into();
Expand Down Expand Up @@ -273,11 +274,11 @@ impl<C: CompactJson + Claims, P: Provider + Configurable> Client<P, C> {
/// - Validation::Expired::Expires if the current time is past the expiration time
/// - Validation::Expired::MaxAge is the token is older than the provided max_age
/// - Validation::Missing::Authtime if a max_age was given and the token has no auth time
pub fn validate_token(
pub fn validate_token<'nonce, 'max_age>(
&self,
token: &IdToken<C>,
nonce: Option<&str>,
max_age: Option<&Duration>,
nonce: impl Into<Option<&'nonce str>>,
max_age: impl Into<Option<&'max_age Duration>>,
) -> Result<(), Error> {
let claims = token.payload()?;
let config = self.config();
Expand Down Expand Up @@ -354,16 +355,16 @@ where
pub fn new(
provider: P,
client_id: String,
client_secret: String,
redirect_uri: Option<String>,
client_secret: impl Into<Option<String>>,
redirect_uri: impl Into<Option<String>>,
http_client: reqwest::Client,
jwks: Option<JWKSet<Empty>>,
) -> Self {
Client {
provider,
client_id,
client_secret,
redirect_uri,
client_secret: client_secret.into(),
redirect_uri: redirect_uri.into(),
http_client,
jwks,
marker: PhantomData,
Expand Down Expand Up @@ -393,7 +394,11 @@ where
/// None,
/// );
/// ```
pub fn auth_uri(&self, scope: Option<&str>, state: Option<&str>) -> Url {
pub fn auth_uri<'scope, 'state>(
&self,
scope: impl Into<Option<&'scope str>>,
state: impl Into<Option<&'state str>>,
) -> Url {
let mut uri = self.provider.auth_uri().clone();

{
Expand All @@ -405,39 +410,17 @@ where
if let Some(ref redirect_uri) = self.redirect_uri {
query.append_pair("redirect_uri", redirect_uri);
}
if let Some(scope) = scope {
query.append_pair("scope", scope);
}
if let Some(state) = state {

self.append_scope(&mut query, scope);

if let Some(state) = state.into() {
query.append_pair("state", state);
}
}

uri
}

async fn post_token(&self, body: String) -> Result<Value, ClientError> {
let json = self
.http_client
.post(self.provider.token_uri().clone())
.basic_auth(&self.client_id, Some(&self.client_secret))
.header(ACCEPT, "application/json")
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(body)
.send()
.await?
.json::<Value>()
.await?;

let error: Result<OAuth2Error, _> = serde_json::from_value(json.clone());

if let Ok(error) = error {
Err(ClientError::from(error))
} else {
Ok(json)
}
}

/// Requests an access token using an authorization code.
///
/// See [RFC 6749, section 4.1.3](http://tools.ietf.org/html/rfc6749#section-4.1.3).
Expand All @@ -453,10 +436,8 @@ where
body.append_pair("redirect_uri", redirect_uri);
}

if self.provider.credentials_in_body() {
body.append_pair("client_id", &self.client_id);
body.append_pair("client_secret", &self.client_secret);
}
self.append_credentials(&mut body);

body.finish()
};

Expand All @@ -472,7 +453,7 @@ where
&self,
username: &str,
password: &str,
scope: Option<&str>,
scope: impl Into<Option<&str>>,
) -> Result<Bearer, ClientError> {
// Ensure the non thread-safe `Serializer` is not kept across
// an `await` boundary by localizing it to this inner scope.
Expand All @@ -481,12 +462,10 @@ where
body.append_pair("grant_type", "password");
body.append_pair("username", username);
body.append_pair("password", password);
body.append_pair("client_id", &self.client_id);
body.append_pair("client_secret", &self.client_secret);

if let Some(scope) = scope {
body.append_pair("scope", scope);
}
self.append_scope(&mut body, scope);

self.append_credentials(&mut body);

body.finish()
};
Expand All @@ -499,14 +478,20 @@ where
/// Requests an access token using the Client Credentials Grant flow
///
/// See [RFC 6749, section 4.4](https://tools.ietf.org/html/rfc6749#section-4.4)
pub async fn request_token_using_client_credentials(&self) -> Result<Bearer, ClientError> {
pub async fn request_token_using_client_credentials(
&self,
scope: impl Into<Option<&str>>,
) -> Result<Bearer, ClientError> {
// Ensure the non thread-safe `Serializer` is not kept across
// an `await` boundary by localizing it to this inner scope.
let body = {
let mut body = Serializer::new(String::new());
body.append_pair("grant_type", "client_credentials");
body.append_pair("client_id", &self.client_id);
body.append_pair("client_secret", &self.client_secret);

self.append_scope(&mut body, scope);

self.append_credentials(&mut body);

body.finish()
};

Expand All @@ -521,7 +506,7 @@ where
pub async fn refresh_token(
&self,
token: Bearer,
scope: Option<&str>,
scope: impl Into<Option<&str>>,
) -> Result<Bearer, ClientError> {
// Ensure the non thread-safe `Serializer` is not kept across
// an `await` boundary by localizing it to this inner scope.
Expand All @@ -536,14 +521,9 @@ where
.expect("No refresh_token field"),
);

if let Some(scope) = scope {
body.append_pair("scope", scope);
}
self.append_scope(&mut body, scope);

if self.provider.credentials_in_body() {
body.append_pair("client_id", &self.client_id);
body.append_pair("client_secret", &self.client_secret);
}
self.append_credentials(&mut body);

body.finish()
};
Expand All @@ -564,6 +544,49 @@ where
Ok(token)
}
}

async fn post_token(&self, body: String) -> Result<Value, ClientError> {
let json = self
.http_client
.post(self.provider.token_uri().clone())
.basic_auth(&self.client_id, self.client_secret.as_ref())
.header(ACCEPT, "application/json")
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(body)
.send()
.await?
.json::<Value>()
.await?;

let error: Result<OAuth2Error, _> = serde_json::from_value(json.clone());

if let Ok(error) = error {
Err(ClientError::from(error))
} else {
Ok(json)
}
}

fn append_credentials(&self, body: &mut Serializer<String>) {
if self.provider.credentials_in_body() {
body.append_pair("client_id", &self.client_id);
if let Some(client_secret) = self.client_secret.as_deref() {
body.append_pair("client_secret", client_secret);
}
}
}

fn append_scope<'scope, T>(
&self,
body: &mut Serializer<T>,
scope: impl Into<Option<&'scope str>>,
) where
T: url::form_urlencoded::Target,
{
if let Some(scope) = scope.into() {
body.append_pair("scope", scope);
}
}
}

#[cfg(test)]
Expand Down
3 changes: 2 additions & 1 deletion src/uma2/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::Config;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::Config;

#[derive(Debug, Deserialize, Serialize)]
pub struct Uma2Config {
// UMA2 additions
Expand Down
5 changes: 3 additions & 2 deletions src/uma2/discovered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
uma2::{Uma2Config, Uma2Provider},
Claims, Client, Config, Configurable, Provider,
};

use biscuit::CompactJson;
use url::Url;

Expand Down Expand Up @@ -53,7 +54,7 @@ impl<C: CompactJson + Claims> Client<DiscoveredUma2, C> {
pub async fn discover_uma2(
id: String,
secret: String,
redirect: Option<String>,
redirect: impl Into<Option<String>>,
issuer: Url,
) -> Result<Self, Error> {
let http_client = reqwest::Client::new();
Expand All @@ -67,7 +68,7 @@ impl<C: CompactJson + Claims> Client<DiscoveredUma2, C> {
provider,
id,
secret,
redirect,
redirect.into(),
http_client,
Some(jwks),
))
Expand Down
Loading

0 comments on commit 416e3eb

Please sign in to comment.