Skip to content

Commit

Permalink
Implement custom JWT assertion signing (#1001) (#1215)
Browse files Browse the repository at this point in the history
* Implement custom JWT assertion signing (#1001)

- Can be used to sign with KMS services instead of local private key

* added typeInfoAnnotation mustache template

* cleanup unused import

---------

Co-authored-by: Clément Denis <clement.denis@gmail.com>
  • Loading branch information
arvindkrishnakumar-okta and clementdenis authored May 15, 2024
1 parent 596a52a commit 5abbeb6
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 8 deletions.
11 changes: 11 additions & 0 deletions api/src/main/java/com/okta/sdk/client/ClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.file.Path;
import java.security.PrivateKey;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
*
Expand Down Expand Up @@ -241,6 +242,16 @@ public interface ClientBuilder {
*/
ClientBuilder setPrivateKey(PrivateKey privateKey);

/**
* Allows specifying a custom signer for signing JWT token, instead of using a locally stored private key.
*
* @param jwtSigner the JWT signer instance.
* @return the ClientBuilder instance for method chaining.
*
* @since 16.x.x
*/
ClientBuilder setCustomJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm);

/**
* Allows specifying the user obtained OAuth2 access token to be used by the SDK.
* The SDK will NOT obtain access token automatically (using the supplied private key)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
{{!
Copyright 2021-Present Okta, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
}}
{{!
Based on https://github.com/OpenAPITools/openapi-generator/blob/v6.6.0/modules/openapi-generator/src/main/resources/Java/typeInfoAnnotation.mustache
- Add defaultImpl to deserialize to base type if discriminator is null or unknown
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import java.security.PrivateKey;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -447,8 +448,12 @@ private void validateOAuth2ClientConfig(ClientConfiguration clientConfiguration)
"At least one scope is required");
String privateKey = clientConfiguration.getPrivateKey();
String oAuth2AccessToken = clientConfiguration.getOAuth2AccessToken();
Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken),
"Either Private Key (or) Access Token must be supplied for OAuth2 Authentication mode");
UnaryOperator<byte[]> jwtSigner = clientConfiguration.getJwtSigner();
String jwtSigningAlgorithm = clientConfiguration.getJwtSigningAlgorithm();
Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken)
|| Objects.nonNull(jwtSigner) && Objects.nonNull(jwtSigningAlgorithm),
"Either Private Key (or) Access Token (or) JWT Signer + Algorithm" +
" must be supplied for OAuth2 Authentication mode");

if (Strings.hasText(privateKey) && !ConfigUtil.hasPrivateKeyContentWrapper(privateKey)) {
// privateKey is a file path, check if the file exists
Expand Down Expand Up @@ -575,6 +580,14 @@ private String readFromInputStream(InputStream inputStream) throws IOException {
return resultStringBuilder.toString();
}

@Override
public ClientBuilder setCustomJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm) {
Assert.notNull(jwtSigner, "jwtSigner cannot be null.");
Assert.notNull(algorithm, "algorithm cannot be null.");
clientConfig.setJwtSigner(jwtSigner, algorithm);
return this;
}

@Override
public ClientBuilder setClientId(String clientId) {
ConfigurationValidator.assertClientId(clientId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
* This class holds the default configuration properties.
Expand All @@ -54,6 +55,8 @@ public class ClientConfiguration extends HttpClientConfiguration {
private String privateKey;
private String oAuth2AccessToken;
private String kid;
private UnaryOperator<byte[]> jwtSigner;
private String jwtSigningAlgorithm;

public String getApiToken() {
return apiToken;
Expand Down Expand Up @@ -151,6 +154,23 @@ public void setKid(String kid) {
this.kid = kid;
}

public UnaryOperator<byte[]> getJwtSigner() {
return jwtSigner;
}

public void setJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm) {
this.jwtSigner = jwtSigner;
this.jwtSigningAlgorithm = algorithm;
}

public String getJwtSigningAlgorithm() {
return jwtSigningAlgorithm;
}

public boolean hasCustomJwtSigner() {
return jwtSigner != null && jwtSigningAlgorithm != null;
}

/**
* Time to idle for cache manager in seconds
* @return seconds until time to idle expires
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,32 @@
import com.okta.sdk.impl.api.DefaultClientCredentialsResolver;
import com.okta.sdk.impl.config.ClientConfiguration;
import com.okta.sdk.impl.util.ConfigUtil;
import com.okta.sdk.resource.client.ApiClient;
import com.okta.sdk.resource.client.ApiException;
import com.okta.sdk.resource.model.HttpMethod;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import io.jsonwebtoken.security.SecureRequest;
import io.jsonwebtoken.security.SecurityException;
import io.jsonwebtoken.security.VerifySecureDigestRequest;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import com.okta.sdk.resource.client.ApiClient;
import com.okta.sdk.resource.client.ApiException;
import com.okta.sdk.resource.model.HttpMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.time.Instant;
Expand All @@ -60,8 +67,48 @@ public class AccessTokenRetrieverServiceImpl implements AccessTokenRetrieverServ

static final String TOKEN_URI = "/oauth2/v1/token";

private static final KeyPair DUMMY_KEY_PAIR = Jwts.SIG.RS256.keyPair().build();

/**
* Custom SecureDigestAlgorithm that delegates signature to the jwtSigner in tokenClientConfiguration
*/
private class CustomJwtSigningAlgorithm implements SecureDigestAlgorithm<PrivateKey, Key> {
@Override
public byte[] digest(SecureRequest<InputStream, PrivateKey> request) throws SecurityException {
try {
byte[] bytes = readAllBytes(request.getPayload());
return tokenClientConfiguration.getJwtSigner().apply(bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

//to replace with InputStream.readAllBytes after migrating to Java 9+
private byte[] readAllBytes(InputStream payload) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
int nRead;
byte[] data = new byte[16384];
while ((nRead = payload.read(data, 0, data.length)) != -1) {
buffer.write(data, 0, nRead);
}
return buffer.toByteArray();
}

@Override
public boolean verify(VerifySecureDigestRequest<Key> request) throws SecurityException {
//no need to verify JWTs
throw new UnsupportedOperationException();
}

@Override
public String getId() {
return tokenClientConfiguration.getJwtSigningAlgorithm();
}
}

private final ClientConfiguration tokenClientConfiguration;
private final ApiClient apiClient;
private final CustomJwtSigningAlgorithm customJwtSigningAlgorithm = new CustomJwtSigningAlgorithm();

public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration, ApiClient apiClient) {
Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null.");
Expand Down Expand Up @@ -133,7 +180,6 @@ public OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyEx
*/
String createSignedJWT() throws InvalidKeyException, IOException {
String clientId = tokenClientConfiguration.getClientId();
PrivateKey privateKey = parsePrivateKey(getPemReader());
Instant now = Instant.now();

JwtBuilder builder = Jwts.builder()
Expand All @@ -142,8 +188,14 @@ String createSignedJWT() throws InvalidKeyException, IOException {
.expiration(Date.from(now.plus(50, ChronoUnit.MINUTES))) // see Javadoc
.issuer(clientId)
.subject(clientId)
.claim("jti", UUID.randomUUID().toString())
.signWith(privateKey);
.claim("jti", UUID.randomUUID().toString());

if (tokenClientConfiguration.hasCustomJwtSigner()) {
//JwtBuilder requires a key to be passed, even if it's actually not used by the algorithm
builder.signWith(DUMMY_KEY_PAIR.getPrivate(), customJwtSigningAlgorithm);
} else {
builder = builder.signWith(parsePrivateKey(getPemReader()));
}

if (Strings.hasText(tokenClientConfiguration.getKid())) {
builder.header().add("kid", tokenClientConfiguration.getKid());
Expand Down Expand Up @@ -248,6 +300,7 @@ ClientConfiguration constructTokenClientConfig(ClientConfiguration apiClientConf
tokenClientConfiguration.setClientId(apiClientConfiguration.getClientId());
tokenClientConfiguration.setScopes(apiClientConfiguration.getScopes());
tokenClientConfiguration.setPrivateKey(apiClientConfiguration.getPrivateKey());
tokenClientConfiguration.setJwtSigner(apiClientConfiguration.getJwtSigner(), apiClientConfiguration.getJwtSigningAlgorithm());
tokenClientConfiguration.setKid(apiClientConfiguration.getKid());

// setting this to '0' will disable this check and only 'retryMaxAttempts' will be effective
Expand Down

0 comments on commit 5abbeb6

Please sign in to comment.