Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor KeyProvider to receive the "Key Id" #167

Merged
merged 6 commits into from
May 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,65 @@ The library implements JWT Verification and Signing using the following algorith

## Usage

### Pick the Algorithm

The Algorithm defines how a token is signed and verified. It can be instantiated with the raw value of the secret in the case of HMAC algorithms, or the key pairs or `KeyProvider` in the case of RSA and ECDSA algorithms. Once created, the instance is reusable for token signing and verification operations.

#### Using static secrets or keys:

```java
//HMAC
Algorithm algorithmHS = Algorithm.HMAC256("secret");

//RSA
RSAPublicKey publicKey = //Get the key instance
RSAPrivateKey privateKey = //Get the key instance
Algorithm algorithmRS = Algorithm.RSA256(publicKey, privateKey);
```

#### Using a KeyProvider:

By using a `KeyProvider` you can change in runtime the key used either to verify the token signature or to sign a new token for RSA or ECDSA algorithms. This is achieved by implementing either `RSAKeyProvider` or `ECDSAKeyProvider` methods:

- `getPublicKeyById(String kid)`: Its called during token signature verification and it should return the key used to verify the token. If key rotation is being used, e.g. [JWK](https://tools.ietf.org/html/rfc7517) it can fetch the correct rotation key using the id. (Or just return the same key all the time).
- `getPrivateKey()`: Its called during token signing and it should return the key that will be used to sign the JWT.
- `getPrivateKeyId()`: Its called during token signing and it should return the id of the key that identifies the one returned by `getPrivateKey()`. This value is preferred over the one set in the `JWTCreator.Builder#withKeyId(String)` method. If you don't need to set a `kid` value avoid instantiating an Algorithm using a `KeyProvider`.


The following snippet uses example classes showing how this would work:


```java
final JwkStore jwkStore = new JwkStore("{JWKS_FILE_HOST}");
final RSAPrivateKey privateKey = //Get the key instance
final String privateKeyId = //Create an Id for the above key

RSAKeyProvider keyProvider = new RSAKeyProvider() {
@Override
public RSAPublicKey getPublicKeyById(String kid) {
//Received 'kid' value might be null if it wasn't defined in the Token's header
RSAPublicKey publicKey = jwkStore.get(kid);
return (RSAPublicKey) publicKey;
}

@Override
public RSAPrivateKey getPrivateKey() {
return privateKey;
}

@Override
public String getPrivateKeyId() {
return privateKeyId;
}
};

Algorithm algorithm = Algorithm.RSA256(keyProvider);
//Use the Algorithm to create and verify JWTs.
```

> For simple key rotation using JWKs try the [jwks-rsa-java](https://github.com/auth0/jwks-rsa-java) library.


### Create and Sign a Token

You'll first need to create a `JWTCreator` instance by calling `JWT.create()`. Use the builder to define the custom Claims your token needs to have. Finally to get the String token call `sign()` and pass the `Algorithm` instance.
Expand Down Expand Up @@ -220,7 +279,7 @@ When creating a Token with the `JWT.create()` you can specify header Claims by c

```java
Map<String, Object> headerClaims = new HashMap();
headerclaims.put("owner", "auth0");
headerClaims.put("owner", "auth0");
String token = JWT.create()
.withHeader(headerClaims)
.sign(algorithm);
Expand Down
2 changes: 1 addition & 1 deletion lib/src/main/java/com/auth0/jwt/JWT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.auth0.jwt.interfaces.Verification;

@SuppressWarnings("WeakerAccess")
public abstract class JWT implements DecodedJWT {
public abstract class JWT {

/**
* Decode a given Json Web Token.
Expand Down
9 changes: 7 additions & 2 deletions lib/src/main/java/com/auth0/jwt/JWTCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public Builder withHeader(Map<String, Object> headerClaims) {

/**
* Add a specific Key Id ("kid") claim to the Header.
* If the {@link Algorithm} used to sign this token was instantiated with a KeyProvider, the 'kid' value will be taken from that provider and this one will be ignored.
*
* @param keyId the Key Id value.
* @return this same Builder instance.
Expand Down Expand Up @@ -303,6 +304,10 @@ public String sign(Algorithm algorithm) throws IllegalArgumentException, JWTCrea
}
headerClaims.put(PublicClaims.ALGORITHM, algorithm.getName());
headerClaims.put(PublicClaims.TYPE, "JWT");
String signingKeyId = algorithm.getSigningKeyId();
if (signingKeyId != null) {
withKeyId(signingKeyId);
}
return new JWTCreator(algorithm, headerClaims, payloadClaims).sign();
}

Expand All @@ -322,8 +327,8 @@ private void addClaim(String name, Object value) {
}

private String sign() throws SignatureGenerationException {
String header = Base64.encodeBase64URLSafeString((headerJson.getBytes(StandardCharsets.UTF_8)));
String payload = Base64.encodeBase64URLSafeString((payloadJson.getBytes(StandardCharsets.UTF_8)));
String header = Base64.encodeBase64URLSafeString(headerJson.getBytes(StandardCharsets.UTF_8));
String payload = Base64.encodeBase64URLSafeString(payloadJson.getBytes(StandardCharsets.UTF_8));
String content = String.format("%s.%s", header, payload);

byte[] signatureBytes = algorithm.sign(content.getBytes(StandardCharsets.UTF_8));
Expand Down
32 changes: 18 additions & 14 deletions lib/src/main/java/com/auth0/jwt/JWTDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.impl.JWTParser;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Header;
import com.auth0.jwt.interfaces.Payload;
import org.apache.commons.codec.binary.Base64;
Expand All @@ -16,20 +17,14 @@
* The JWTDecoder class holds the decode method to parse a given JWT token into it's JWT representation.
*/
@SuppressWarnings("WeakerAccess")
final class JWTDecoder extends JWT {
final class JWTDecoder implements DecodedJWT {

private final String token;
private Header header;
private Payload payload;
private String signature;
private final String[] parts;
private final Header header;
private final Payload payload;

JWTDecoder(String jwt) throws JWTDecodeException {
this.token = jwt;
parseToken(jwt);
}

private void parseToken(String token) throws JWTDecodeException {
final String[] parts = TokenUtils.splitToken(token);
parts = TokenUtils.splitToken(jwt);
final JWTParser converter = new JWTParser();
String headerJson;
String payloadJson;
Expand All @@ -41,7 +36,6 @@ private void parseToken(String token) throws JWTDecodeException {
}
header = converter.parseHeader(headerJson);
payload = converter.parsePayload(payloadJson);
signature = parts[2];
}

@Override
Expand Down Expand Up @@ -114,13 +108,23 @@ public Map<String, Claim> getClaims() {
return payload.getClaims();
}

@Override
public String getHeader() {
return parts[0];
}

@Override
public String getPayload() {
return parts[1];
}

@Override
public String getSignature() {
return signature;
return parts[2];
}

@Override
public String getToken() {
return token;
return String.format("%s.%s.%s", parts[0], parts[1], parts[2]);
}
}
72 changes: 30 additions & 42 deletions lib/src/main/java/com/auth0/jwt/JWTVerifier.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
package com.auth0.jwt;

import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.AlgorithmMismatchException;
import com.auth0.jwt.exceptions.InvalidClaimException;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.exceptions.SignatureVerificationException;
import com.auth0.jwt.exceptions.TokenExpiredException;
import com.auth0.jwt.exceptions.*;
import com.auth0.jwt.impl.PublicClaims;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.Clock;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import org.apache.commons.codec.binary.Base64;

import java.nio.charset.StandardCharsets;
import java.util.*;

/**
Expand Down Expand Up @@ -349,29 +343,26 @@ private void requireClaim(String name, Object value) {
*
* @param token to verify.
* @return a verified and decoded JWT.
* @throws JWTVerificationException if any of the required contents inside the JWT is invalid.
* @throws AlgorithmMismatchException if the algorithm stated in the token's header it's not equal to the one defined in the {@link JWTVerifier}.
* @throws SignatureVerificationException if the signature is invalid.
* @throws TokenExpiredException if the token has expired.
* @throws InvalidClaimException if a claim contained a different value than the expected one.
*/
public DecodedJWT verify(String token) throws JWTVerificationException {
DecodedJWT jwt = JWTDecoder.decode(token);
DecodedJWT jwt = JWT.decode(token);
verifyAlgorithm(jwt, algorithm);
verifySignature(TokenUtils.splitToken(token));
algorithm.verify(jwt);
verifyClaims(jwt, claims);
return jwt;
}

private void verifySignature(String[] parts) throws SignatureVerificationException {
byte[] content = String.format("%s.%s", parts[0], parts[1]).getBytes(StandardCharsets.UTF_8);
byte[] signature = Base64.decodeBase64(parts[2]);
algorithm.verify(content, signature);
}

private void verifyAlgorithm(DecodedJWT jwt, Algorithm expectedAlgorithm) throws AlgorithmMismatchException {
if (!expectedAlgorithm.getName().equals(jwt.getAlgorithm())) {
throw new AlgorithmMismatchException("The provided Algorithm doesn't match the one defined in the JWT's Header.");
}
}

private void verifyClaims(DecodedJWT jwt, Map<String, Object> claims) {
private void verifyClaims(DecodedJWT jwt, Map<String, Object> claims) throws TokenExpiredException, InvalidClaimException {
for (Map.Entry<String, Object> entry : claims.entrySet()) {
switch (entry.getKey()) {
case PublicClaims.AUDIENCE:
Expand Down Expand Up @@ -435,31 +426,28 @@ private void assertValidStringClaim(String claimName, String value, String expec
}

private void assertValidDateClaim(Date date, long leeway, boolean shouldBeFuture) {
Date today = clock.getToday();
today.setTime((long) Math.floor((today.getTime() / 1000) * 1000)); // truncate
// millis
if (shouldBeFuture) {
assertDateIsFuture(date, leeway, today);
} else {
assertDateIsPast(date, leeway, today);
}
}

private void assertDateIsFuture(Date date, long leeway, Date today) {

today.setTime(today.getTime() - leeway * 1000);
if (date != null && today.after(date)) {
throw new TokenExpiredException(String.format("The Token has expired on %s.", date));
}
}

private void assertDateIsPast(Date date, long leeway, Date today) {
today.setTime(today.getTime() + leeway * 1000);
if(date!=null && today.before(date)) {
throw new InvalidClaimException(String.format("The Token can't be used before %s.", date));
}

}
Date today = clock.getToday();
today.setTime((long) Math.floor((today.getTime() / 1000) * 1000)); // truncate millis
if (shouldBeFuture) {
assertDateIsFuture(date, leeway, today);
} else {
assertDateIsPast(date, leeway, today);
}
}

private void assertDateIsFuture(Date date, long leeway, Date today) {
today.setTime(today.getTime() - leeway * 1000);
if (date != null && today.after(date)) {
throw new TokenExpiredException(String.format("The Token has expired on %s.", date));
}
}

private void assertDateIsPast(Date date, long leeway, Date today) {
today.setTime(today.getTime() + leeway * 1000);
if (date != null && today.before(date)) {
throw new InvalidClaimException(String.format("The Token can't be used before %s.", date));
}
}

private void assertValidAudienceClaim(List<String> audience, List<String> value) {
if (audience == null || !audience.containsAll(value)) {
Expand Down
31 changes: 20 additions & 11 deletions lib/src/main/java/com/auth0/jwt/algorithms/Algorithm.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import com.auth0.jwt.exceptions.SignatureGenerationException;
import com.auth0.jwt.exceptions.SignatureVerificationException;
import com.auth0.jwt.interfaces.ECKeyProvider;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.ECDSAKeyProvider;
import com.auth0.jwt.interfaces.RSAKeyProvider;

import java.io.UnsupportedEncodingException;
Expand Down Expand Up @@ -207,7 +208,7 @@ public static Algorithm HMAC512(byte[] secret) throws IllegalArgumentException {
* @return a valid ECDSA256 Algorithm.
* @throws IllegalArgumentException if the Key Provider is null.
*/
public static Algorithm ECDSA256(ECKeyProvider keyProvider) throws IllegalArgumentException {
public static Algorithm ECDSA256(ECDSAKeyProvider keyProvider) throws IllegalArgumentException {
return new ECDSAAlgorithm("ES256", "SHA256withECDSA", 32, keyProvider);
}

Expand All @@ -229,7 +230,7 @@ public static Algorithm ECDSA256(ECPublicKey publicKey, ECPrivateKey privateKey)
* @param key the key to use in the verify or signing instance.
* @return a valid ECDSA256 Algorithm.
* @throws IllegalArgumentException if the provided Key is null.
* @deprecated use {@link #ECDSA256(ECPublicKey, ECPrivateKey)} or {@link #ECDSA256(ECKeyProvider)}
* @deprecated use {@link #ECDSA256(ECPublicKey, ECPrivateKey)} or {@link #ECDSA256(ECDSAKeyProvider)}
*/
@Deprecated
public static Algorithm ECDSA256(ECKey key) throws IllegalArgumentException {
Expand All @@ -245,7 +246,7 @@ public static Algorithm ECDSA256(ECKey key) throws IllegalArgumentException {
* @return a valid ECDSA384 Algorithm.
* @throws IllegalArgumentException if the Key Provider is null.
*/
public static Algorithm ECDSA384(ECKeyProvider keyProvider) throws IllegalArgumentException {
public static Algorithm ECDSA384(ECDSAKeyProvider keyProvider) throws IllegalArgumentException {
return new ECDSAAlgorithm("ES384", "SHA384withECDSA", 48, keyProvider);
}

Expand All @@ -267,7 +268,7 @@ public static Algorithm ECDSA384(ECPublicKey publicKey, ECPrivateKey privateKey)
* @param key the key to use in the verify or signing instance.
* @return a valid ECDSA384 Algorithm.
* @throws IllegalArgumentException if the provided Key is null.
* @deprecated use {@link #ECDSA384(ECPublicKey, ECPrivateKey)} or {@link #ECDSA384(ECKeyProvider)}
* @deprecated use {@link #ECDSA384(ECPublicKey, ECPrivateKey)} or {@link #ECDSA384(ECDSAKeyProvider)}
*/
@Deprecated
public static Algorithm ECDSA384(ECKey key) throws IllegalArgumentException {
Expand All @@ -283,7 +284,7 @@ public static Algorithm ECDSA384(ECKey key) throws IllegalArgumentException {
* @return a valid ECDSA512 Algorithm.
* @throws IllegalArgumentException if the Key Provider is null.
*/
public static Algorithm ECDSA512(ECKeyProvider keyProvider) throws IllegalArgumentException {
public static Algorithm ECDSA512(ECDSAKeyProvider keyProvider) throws IllegalArgumentException {
return new ECDSAAlgorithm("ES512", "SHA512withECDSA", 66, keyProvider);
}

Expand All @@ -305,7 +306,7 @@ public static Algorithm ECDSA512(ECPublicKey publicKey, ECPrivateKey privateKey)
* @param key the key to use in the verify or signing instance.
* @return a valid ECDSA512 Algorithm.
* @throws IllegalArgumentException if the provided Key is null.
* @deprecated use {@link #ECDSA512(ECPublicKey, ECPrivateKey)} or {@link #ECDSA512(ECKeyProvider)}
* @deprecated use {@link #ECDSA512(ECPublicKey, ECPrivateKey)} or {@link #ECDSA512(ECDSAKeyProvider)}
*/
@Deprecated
public static Algorithm ECDSA512(ECKey key) throws IllegalArgumentException {
Expand All @@ -324,6 +325,15 @@ protected Algorithm(String name, String description) {
this.description = description;
}

/**
* Getter for the Id of the Private Key used to sign the tokens. This is usually specified as the `kid` claim in the Header.
*
* @return the Key Id that identifies the Signing Key or null if it's not specified.
*/
public String getSigningKeyId() {
return null;
}

/**
* Getter for the name of this Algorithm, as defined in the JWT Standard. i.e. "HS256"
*
Expand All @@ -348,13 +358,12 @@ public String toString() {
}

/**
* Verify the given content using this Algorithm instance.
* Verify the given token using this Algorithm instance.
*
* @param contentBytes an array of bytes representing the base64 encoded content to be verified against the signature.
* @param signatureBytes an array of bytes representing the base64 encoded signature to compare the content against.
* @param jwt the already decoded JWT that it's going to be verified.
* @throws SignatureVerificationException if the Token's Signature is invalid, meaning that it doesn't match the signatureBytes, or if the Key is invalid.
*/
public abstract void verify(byte[] contentBytes, byte[] signatureBytes) throws SignatureVerificationException;
public abstract void verify(DecodedJWT jwt) throws SignatureVerificationException;

/**
* Sign the given content using this Algorithm instance.
Expand Down
Loading