Skip to content

Commit

Permalink
Add required_spec_claims (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
Keats committed Feb 2, 2022
1 parent 356fac0 commit 255c740
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Allow float values for `exp` and `nbf`, yes it's in the spec... floats will be rounded and converted to u64
- Error now implements Clone/Eq
- Change default leeway from 0s to 60s
- Add `Validation::require_spec_claims` to validate presence of the spec claims

## 7.2.0 (2020-06-30)

Expand Down
4 changes: 4 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub enum ErrorKind {
InvalidKeyFormat,

// Validation errors
/// When a claim required by the validation is not present
MissingRequiredClaim(String),
/// When a token’s `exp` claim indicates that it has expired
ExpiredSignature,
/// When a token’s `iss` claim does not match the expected issuer
Expand Down Expand Up @@ -88,6 +90,7 @@ impl StdError for Error {
ErrorKind::InvalidRsaKey(_) => None,
ErrorKind::ExpiredSignature => None,
ErrorKind::MissingAlgorithm => None,
ErrorKind::MissingRequiredClaim(_) => None,
ErrorKind::InvalidIssuer => None,
ErrorKind::InvalidAudience => None,
ErrorKind::InvalidSubject => None,
Expand Down Expand Up @@ -119,6 +122,7 @@ impl fmt::Display for Error {
| ErrorKind::InvalidAlgorithm
| ErrorKind::InvalidKeyFormat
| ErrorKind::InvalidAlgorithmName => write!(f, "{:?}", self.0),
ErrorKind::MissingRequiredClaim(ref c) => write!(f, "Missing required claim: {}", c),
ErrorKind::InvalidRsaKey(ref msg) => write!(f, "RSA key invalid: {}", msg),
ErrorKind::Json(ref err) => write!(f, "JSON error: {}", err),
ErrorKind::Utf8(ref err) => write!(f, "UTF-8 error: {}", err),
Expand Down
80 changes: 55 additions & 25 deletions src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ use crate::errors::{new_error, ErrorKind, Result};
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Validation {
/// Which claims are required to be present before starting the validation
/// Which claims are required to be present before starting the validation.
/// This does not interact with the various `validate_*`. If you remove `exp` from that list, you still need
/// to set `validate_exp` to `false`.
/// The only value that will be used are "exp", "nbf", "aud", "iss", "sub". Anything else will be ignored.
///
/// Defaults to `{"exp"}`
pub required_claims: HashSet<String>,
pub required_spec_claims: HashSet<String>,
/// Add some leeway (in seconds) to the `exp`, `iat` and `nbf` validation to
/// account for clock skew.
///
Expand Down Expand Up @@ -94,11 +97,13 @@ impl Validation {
self.iss = Some(items.iter().map(|x| x.to_string()).collect())
}

/// Which claims are required to be present for this JWT to be considered valid
/// This is not restricted to the claims from the JWT spec, you can add your own custom ones.
/// The simple usage is `set_required_claims(&["exp", "my_claim"])`
pub fn set_required_claims<T: ToString>(&mut self, items: &[T]) {
self.required_claims = items.iter().map(|x| x.to_string()).collect();
/// Which claims are required to be present for this JWT to be considered valid.
/// The only values that will be considered are "exp", "nbf", "aud", "iss", "sub".
/// The simple usage is `set_required_claims(&["exp", "nbf"])`.
/// If you want to have an empty set, do not use this function - set an empty set on the struct
/// param directly.
pub fn set_required_spec_claims<T: ToString>(&mut self, items: &[T]) {
self.required_spec_claims = items.iter().map(|x| x.to_string()).collect();
}

/// Whether to validate the JWT cryptographic signature
Expand All @@ -115,7 +120,7 @@ impl Default for Validation {
required_claims.insert("exp".to_owned());

Validation {
required_claims,
required_spec_claims: required_claims,
algorithms: vec![Algorithm::HS256],
leeway: 60,

Expand Down Expand Up @@ -186,6 +191,7 @@ enum Issuer<'a> {
Single(#[serde(borrow)] Cow<'a, str>),
Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}

/// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where
/// possible (no escapes in the original str) but it does not work on e.g. `HashSet<Cow<str>>`
/// We use this struct in this case.
Expand All @@ -207,12 +213,23 @@ fn is_subset(reference: &HashSet<String>, given: &HashSet<BorrowedCowIfPossible<
}

pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> {
// for required_claim in &options.required_claims {
// if matches!(claims)
// }

let now = get_current_timestamp();

for required_claim in &options.required_spec_claims {
let present = match required_claim.as_str() {
"exp" => matches!(claims.exp, TryParse::Parsed(_)),
"sub" => matches!(claims.sub, TryParse::Parsed(_)),
"iss" => matches!(claims.iss, TryParse::Parsed(_)),
"aud" => matches!(claims.aud, TryParse::Parsed(_)),
"nbf" => matches!(claims.nbf, TryParse::Parsed(_)),
_ => continue,
};

if !present {
return Err(new_error(ErrorKind::MissingRequiredClaim(required_claim.clone())));
}
}

if options.validate_exp
&& !matches!(claims.exp, TryParse::Parsed(exp) if exp >= now-options.leeway)
{
Expand Down Expand Up @@ -304,6 +321,7 @@ mod tests {

use crate::errors::ErrorKind;
use crate::Algorithm;
use std::collections::HashSet;

fn deserialize_claims(claims: &serde_json::Value) -> ClaimsForValidation {
serde::Deserialize::deserialize(claims).unwrap()
Expand Down Expand Up @@ -360,18 +378,17 @@ mod tests {
#[test]
fn validation_called_even_if_field_is_empty() {
let claims = json!({});
let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
assert!(res.is_err());
match res.unwrap_err().kind() {
ErrorKind::ExpiredSignature => (),
_ => unreachable!(),
};
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
let res = validate(deserialize_claims(&claims), &validation).unwrap_err();
assert_eq!(res.kind(), &ErrorKind::ExpiredSignature);
}

#[test]
fn nbf_in_past_ok() {
let claims = json!({ "nbf": get_current_timestamp() - 10000 });
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -382,6 +399,7 @@ mod tests {
fn nbf_float_in_past_ok() {
let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 });
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -392,6 +410,7 @@ mod tests {
fn nbf_in_future_fails() {
let claims = json!({ "nbf": get_current_timestamp() + 100000 });
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.validate_nbf = true;
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -407,6 +426,7 @@ mod tests {
fn nbf_in_future_but_in_leeway_ok() {
let claims = json!({ "nbf": get_current_timestamp() + 500 });
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.validate_nbf = true;
validation.leeway = 1000 * 60;
Expand All @@ -418,6 +438,7 @@ mod tests {
fn iss_string_ok() {
let claims = json!({"iss": ["Keats"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -428,6 +449,7 @@ mod tests {
fn iss_array_of_string_ok() {
let claims = json!({"iss": ["UserA", "UserB"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.set_issuer(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -439,6 +461,7 @@ mod tests {
let claims = json!({"iss": "Hacked"});

let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -455,6 +478,7 @@ mod tests {
let claims = json!({});

let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.set_issuer(&["Keats"]);
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -469,6 +493,7 @@ mod tests {
fn sub_ok() {
let claims = json!({"sub": "Keats"});
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.sub = Some("Keats".to_owned());
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -479,6 +504,7 @@ mod tests {
fn sub_not_matching_fails() {
let claims = json!({"sub": "Hacked"});
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims = HashSet::new();
validation.validate_exp = false;
validation.sub = Some("Keats".to_owned());
let res = validate(deserialize_claims(&claims), &validation);
Expand All @@ -495,6 +521,7 @@ mod tests {
let claims = json!({});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.sub = Some("Keats".to_owned());
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
Expand All @@ -510,6 +537,7 @@ mod tests {
let claims = json!({"aud": ["Everyone"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["Everyone"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok());
Expand All @@ -520,6 +548,7 @@ mod tests {
let claims = json!({"aud": ["UserA", "UserB"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok());
Expand All @@ -530,6 +559,7 @@ mod tests {
let claims = json!({"aud": ["Everyone"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["UserA", "UserB"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
Expand All @@ -545,6 +575,7 @@ mod tests {
let claims = json!({"aud": ["Everyone"]});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["None"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
Expand All @@ -560,6 +591,7 @@ mod tests {
let claims = json!({});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["None"]);
let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
Expand Down Expand Up @@ -599,19 +631,17 @@ mod tests {
aud_hashset.insert(aud);
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]);

let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_ok());
}

#[test]
fn required_claims_complains_if_field_not_found() {
let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"});
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;

let res = validate(deserialize_claims(&claims), &validation);
assert!(res.is_err());
fn errors_when_required_claim_is_missing() {
let claims = json!({});
let res = validate(deserialize_claims(&claims), &Validation::default()).unwrap_err();
assert_eq!(res.kind(), &ErrorKind::MissingRequiredClaim("exp".to_owned()));
}
}

0 comments on commit 255c740

Please sign in to comment.