From 11657ec647243a8caa818f370b392b723ee18548 Mon Sep 17 00:00:00 2001 From: Arvind Krishnakumar <61501885+arvindkrishnakumar-okta@users.noreply.github.com> Date: Thu, 25 Aug 2022 13:35:03 -0700 Subject: [PATCH] [OASv3] - OAuth for Okta (#753) * wip - oauth for okta * fix unit test failure * fix unit test failure --- .../sdk/impl/client/DefaultClientBuilder.java | 69 ++- .../oauth2/AccessTokenRetrieverService.java | 36 ++ .../AccessTokenRetrieverServiceImpl.java | 270 +++++++++++ .../sdk/impl/oauth2/OAuth2AccessToken.java | 102 ++++ .../impl/oauth2/OAuth2ClientCredentials.java | 73 +++ .../sdk/impl/oauth2/OAuth2HttpException.java | 28 ++ .../oauth2/OAuth2TokenRetrieverException.java | 30 ++ ...AccessTokenRetrieverServiceImplTest.groovy | 438 ++++++++++++++++++ 8 files changed, 1036 insertions(+), 10 deletions(-) create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java create mode 100644 impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy diff --git a/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java b/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java index c6790c5021d..f8f86d3a17c 100644 --- a/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java +++ b/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java @@ -44,6 +44,9 @@ import com.okta.sdk.impl.io.DefaultResourceFactory; import com.okta.sdk.impl.io.Resource; import com.okta.sdk.impl.io.ResourceFactory; +import com.okta.sdk.impl.oauth2.AccessTokenRetrieverService; +import com.okta.sdk.impl.oauth2.AccessTokenRetrieverServiceImpl; +import com.okta.sdk.impl.oauth2.OAuth2ClientCredentials; import com.okta.sdk.impl.util.ConfigUtil; import com.okta.sdk.impl.util.DefaultBaseUrlResolver; import org.apache.http.HttpHost; @@ -73,7 +76,11 @@ import java.io.InputStreamReader; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.InvalidPathException; +import java.nio.file.LinkOption; import java.nio.file.Path; +import java.nio.file.Paths; import java.security.PrivateKey; import java.util.ArrayList; import java.util.Arrays; @@ -117,6 +124,8 @@ public class DefaultClientBuilder implements ClientBuilder { private ClientConfiguration clientConfig = new ClientConfiguration(); + private AccessTokenRetrieverService accessTokenRetrieverService; + public DefaultClientBuilder() { this(new DefaultResourceFactory()); } @@ -272,21 +281,57 @@ public ApiClient build() { this.clientConfig.setBaseUrlResolver(new DefaultBaseUrlResolver(this.clientConfig.getBaseUrl())); } - if (this.clientConfig.getClientCredentialsResolver() == null && this.clientCredentials != null) { - this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientCredentials)); - } else if (this.clientConfig.getClientCredentialsResolver() == null) { - this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientConfig)); - } - ApiClient apiClient = new ApiClient(restTemplate(this.clientConfig)); - apiClient.setBasePath(this.clientConfig.getBaseUrl()); - apiClient.setApiKey((String) this.clientConfig.getClientCredentialsResolver().getClientCredentials().getCredentials()); - // for beta release, we support only SSWS, OAuth2 support planned to be added in later release - apiClient.setApiKeyPrefix("SSWS"); + + if (!isOAuth2Flow()) { + if (this.clientConfig.getClientCredentialsResolver() == null && this.clientCredentials != null) { + this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientCredentials)); + } else if (this.clientConfig.getClientCredentialsResolver() == null) { + this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientConfig)); + } + + apiClient.setBasePath(this.clientConfig.getBaseUrl()); + apiClient.setApiKeyPrefix("SSWS"); + apiClient.setApiKey((String) this.clientConfig.getClientCredentialsResolver().getClientCredentials().getCredentials()); + } else { + this.clientConfig.setAuthenticationScheme(AuthenticationScheme.OAUTH2_PRIVATE_KEY); + + validateOAuth2ClientConfig(this.clientConfig); + + accessTokenRetrieverService = new AccessTokenRetrieverServiceImpl(clientConfig, apiClient); + + OAuth2ClientCredentials oAuth2ClientCredentials = + new OAuth2ClientCredentials(accessTokenRetrieverService); + + this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(oAuth2ClientCredentials)); + } return apiClient; } + /** + * @since 1.6.0 + */ + private void validateOAuth2ClientConfig(ClientConfiguration clientConfiguration) { + Assert.notNull(clientConfiguration.getClientId(), "clientId cannot be null"); + Assert.isTrue(clientConfiguration.getScopes() != null && !clientConfiguration.getScopes().isEmpty(), + "At least one scope is required"); + String privateKey = clientConfiguration.getPrivateKey(); + Assert.hasText(privateKey, "privateKey cannot be null (either PEM file path (or) full PEM content must be supplied)"); + + if (!ConfigUtil.hasPrivateKeyContentWrapper(privateKey)) { + // privateKey is a file path, check if the file exists + Path privateKeyPemFilePath; + try { + privateKeyPemFilePath = Paths.get(privateKey); + } catch (InvalidPathException ipe) { + throw new IllegalArgumentException("Invalid privateKey file path", ipe); + } + boolean privateKeyPemFileExists = Files.exists(privateKeyPemFilePath, LinkOption.NOFOLLOW_LINKS); + Assert.isTrue(privateKeyPemFileExists, "privateKey file does not exist"); + } + } + private RestTemplate restTemplate(ClientConfiguration clientConfig) { ObjectMapper objectMapper = new ObjectMapper(); @@ -460,6 +505,10 @@ public ClientBuilder setKid(String kid) { return this; } + boolean isOAuth2Flow() { + return this.getClientConfiguration().getAuthorizationMode() == AuthorizationMode.PRIVATE_KEY; + } + public ClientConfiguration getClientConfiguration() { return clientConfig; } diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java new file mode 100644 index 00000000000..995a7a3a118 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java @@ -0,0 +1,36 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import java.io.IOException; +import java.security.InvalidKeyException; + +/** + * Abstraction for OAuth2 access token retrieval service function. + * + * @since 1.6.0 + */ +public interface AccessTokenRetrieverService { + /** + * Obtain OAuth2 access token from Authorization Server endpoint. + * + * @return {@link OAuth2AccessToken} + * @throws IOException if problems are encountered extracting the input private key + * @throws InvalidKeyException if supplied private key is invalid + * @throws OAuth2TokenRetrieverException if token could not be retrieved + */ + OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyException, OAuth2TokenRetrieverException; +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java new file mode 100644 index 00000000000..80d384d6b67 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java @@ -0,0 +1,270 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.commons.http.authc.DisabledAuthenticator; +import com.okta.commons.lang.Assert; +import com.okta.commons.lang.Strings; +import com.okta.sdk.client.AuthenticationScheme; +import com.okta.sdk.client.AuthorizationMode; +import com.okta.sdk.error.Error; +import com.okta.sdk.error.ResourceException; +import com.okta.sdk.impl.api.DefaultClientCredentialsResolver; +import com.okta.sdk.impl.config.ClientConfiguration; +import com.okta.sdk.impl.util.ConfigUtil; +import io.jsonwebtoken.JwtBuilder; +import io.jsonwebtoken.Jwts; +import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; +import org.bouncycastle.openssl.PEMKeyPair; +import org.bouncycastle.openssl.PEMParser; +import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; +import org.openapitools.client.ApiClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.InvalidKeyException; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Date; +import java.util.Optional; +import java.util.UUID; + +/** + * Implementation of {@link AccessTokenRetrieverService} interface. + * + * This has logic to fetch OAuth2 access token from the Authorization server endpoint. + * + * @since 1.6.0 + */ +public class AccessTokenRetrieverServiceImpl implements AccessTokenRetrieverService { + private static final Logger log = LoggerFactory.getLogger(AccessTokenRetrieverServiceImpl.class); + + private static final String TOKEN_URI = "/oauth2/v1/token"; + + private final ClientConfiguration tokenClientConfiguration; + private final ApiClient apiClient; + + public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration, ApiClient apiClient) { + Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null."); + Assert.notNull(apiClient, "apiClient must not be null."); + this.apiClient = apiClient; + this.tokenClientConfiguration = constructTokenClientConfig(apiClientConfiguration); + } + + /** + * {@inheritDoc} + */ + @Override + public OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyException, OAuth2TokenRetrieverException { + log.debug("Attempting to get OAuth2 access token for client id {} from {}", + tokenClientConfiguration.getClientId(), tokenClientConfiguration.getBaseUrl() + TOKEN_URI); + + String signedJwt = createSignedJWT(); + String scope = String.join(" ", tokenClientConfiguration.getScopes()); + + try { + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + httpHeaders.setContentType(MediaType.APPLICATION_FORM_URLENCODED); + + MultiValueMap queryParams = new LinkedMultiValueMap<>(); + queryParams.add("grant_type", "client_credentials"); + queryParams.add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + queryParams.add("client_assertion", signedJwt); + queryParams.add("scope", scope); + + ResponseEntity responseEntity = apiClient.invokeAPI(TOKEN_URI, + HttpMethod.POST, + Collections.emptyMap(), + queryParams, + null, + httpHeaders, + new LinkedMultiValueMap<>(), + null, + Collections.singletonList(MediaType.APPLICATION_JSON), + MediaType.APPLICATION_JSON, + new String[] { "OAuth_2.0" }, + new ParameterizedTypeReference() {}); + + OAuth2AccessToken oAuth2AccessToken = responseEntity.getBody(); + + log.debug("Got OAuth2 access token for client id {} from {}", + tokenClientConfiguration.getClientId(), tokenClientConfiguration.getBaseUrl() + TOKEN_URI); + + return oAuth2AccessToken; + } catch (ResourceException e) { + Error defaultError = e.getError(); + throw new OAuth2HttpException(defaultError.getMessage(), e, e.getStatus() == 401); + } catch (Exception e) { + throw new OAuth2TokenRetrieverException("Exception while trying to get " + + "OAuth2 access token for client id " + tokenClientConfiguration.getClientId(), e); + } + } + + /** + * Create signed JWT string with the supplied token client configuration details. + * + * Expiration value should be not more than one hour in the future. + * We use 50 minutes in order to have a 10 minutes leeway in case of clock skew. + * + * @return signed JWT string + * @throws InvalidKeyException if the supplied key is invalid + * @throws IOException if the key could not be read + */ + String createSignedJWT() throws InvalidKeyException, IOException { + String clientId = tokenClientConfiguration.getClientId(); + PrivateKey privateKey = parsePrivateKey(getPemReader()); + Instant now = Instant.now(); + + JwtBuilder builder = Jwts.builder() + .setAudience(tokenClientConfiguration.getBaseUrl() + TOKEN_URI) + .setIssuedAt(Date.from(now)) + .setExpiration(Date.from(now.plus(50, ChronoUnit.MINUTES))) // see Javadoc + .setIssuer(clientId) + .setSubject(clientId) + .claim("jti", UUID.randomUUID().toString()) + .signWith(privateKey); + + if (Strings.hasText(tokenClientConfiguration.getKid())) { + builder.setHeaderParam("kid", tokenClientConfiguration.getKid()); + } + + return builder.compact(); + } + + /** + * Parse private key from the supplied configuration. + * + * @param pemReader a {@link Reader} that has access to a full PEM resource + * @return {@link PrivateKey} + * @throws IOException if the private key could not be read + * @throws InvalidKeyException if the supplied key is invalid + */ + PrivateKey parsePrivateKey(Reader pemReader) throws IOException, InvalidKeyException { + + PrivateKey privateKey = getPrivateKeyFromPEM(pemReader); + String algorithm = privateKey.getAlgorithm(); + + if (!algorithm.equals("RSA") && + !algorithm.equals("EC")) { + throw new InvalidKeyException("Supplied privateKey is not an RSA or EC key - " + algorithm); + } + + return privateKey; + } + + private Reader getPemReader() throws IOException { + String privateKey = tokenClientConfiguration.getPrivateKey(); + if (ConfigUtil.hasPrivateKeyContentWrapper(privateKey)) { + return new StringReader(privateKey); + } else { + return Files.newBufferedReader(Paths.get(privateKey), Charset.defaultCharset()); + } + } + + /** + * Get Private key from input PEM file. + * + * @param reader the reader instance + * @return {@link PrivateKey} private key instance + * @throws IOException if the parser could not read the reader object + */ + PrivateKey getPrivateKeyFromPEM(Reader reader) throws IOException { + PrivateKey privateKey; + + try (PEMParser pemParser = new PEMParser(reader)) { + JcaPEMKeyConverter jcaPEMKeyConverter = new JcaPEMKeyConverter(); + Object pemContent = pemParser.readObject(); + + if (pemContent == null) { + throw new IllegalArgumentException("Invalid Private Key PEM file"); + } + + if (pemContent instanceof PEMKeyPair) { + PEMKeyPair pemKeyPair = (PEMKeyPair) pemContent; + KeyPair keyPair = jcaPEMKeyConverter.getKeyPair(pemKeyPair); + privateKey = keyPair.getPrivate(); + } else if (pemContent instanceof PrivateKeyInfo) { + PrivateKeyInfo privateKeyInfo = (PrivateKeyInfo) pemContent; + privateKey = jcaPEMKeyConverter.getPrivateKey(privateKeyInfo); + } else { + throw new IllegalArgumentException("Unsupported Private Key format '" + + pemContent.getClass().getSimpleName() + '"'); + } + } + + return privateKey; + } + + /** + * Create token client config from the supplied API client config. + * + * Token client needs to retry http 401 errors only once which is not the case with the API client. + * We therefore effect this token client specific config by setting 'retryMaxElapsed' + * & 'retryMaxAttempts' fields. + * + * @param apiClientConfiguration supplied in the client configuration. + * @return ClientConfiguration to be used by token retrieval client. + */ + ClientConfiguration constructTokenClientConfig(ClientConfiguration apiClientConfiguration) { + ClientConfiguration tokenClientConfiguration = new ClientConfiguration(); + + tokenClientConfiguration.setClientCredentialsResolver( + new DefaultClientCredentialsResolver(Optional::empty)); + + tokenClientConfiguration.setRequestAuthenticator(new DisabledAuthenticator()); + + // TODO: set this to false explicitly when caching is implemented in OASv3 SDK + //tokenClientConfiguration.setCacheManagerEnabled(false); + + if (apiClientConfiguration.getBaseUrlResolver() != null) + tokenClientConfiguration.setBaseUrlResolver(apiClientConfiguration.getBaseUrlResolver()); + + if (apiClientConfiguration.getProxy() != null) + tokenClientConfiguration.setProxy(apiClientConfiguration.getProxy()); + + tokenClientConfiguration.setAuthenticationScheme(AuthenticationScheme.NONE); + tokenClientConfiguration.setAuthorizationMode(AuthorizationMode.get(tokenClientConfiguration.getAuthenticationScheme())); + tokenClientConfiguration.setClientId(apiClientConfiguration.getClientId()); + tokenClientConfiguration.setScopes(apiClientConfiguration.getScopes()); + tokenClientConfiguration.setPrivateKey(apiClientConfiguration.getPrivateKey()); + tokenClientConfiguration.setKid(apiClientConfiguration.getKid()); + + // setting this to '0' will disable this check and only 'retryMaxAttempts' will be effective + tokenClientConfiguration.setRetryMaxElapsed(0); + // retry only once for token requests (http 401 errors) + tokenClientConfiguration.setRetryMaxAttempts(1); + + return tokenClientConfiguration; + } + +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java new file mode 100644 index 00000000000..8428022df9b --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java @@ -0,0 +1,102 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import java.time.Duration; +import java.time.Instant; + +/** + * Represents the OAuth2 access token returned by Authorization server. + * + * @since 1.6.0 + */ +public class OAuth2AccessToken { + + /* Token body constants */ + public static final String TOKEN_TYPE_KEY = "token_type"; + public static final String EXPIRES_IN_KEY = "expires_in"; + public static final String ACCESS_TOKEN_KEY = "access_token"; + public static final String SCOPE_KEY = "scope"; + + /* Token error constants */ + public static final String ERROR_KEY = "error"; + public static final String ERROR_DESCRIPTION = "error_description"; + + private String tokenType; + + private Integer expiresIn; + + private String accessToken; + + private String scope; + + private Instant issuedAt = Instant.now(); + + public String getTokenType() { + return tokenType; + } + + public void setTokenType(String tokenType) { + this.tokenType = tokenType; + } + + public Integer getExpiresIn() { + return expiresIn; + } + + public void setExpiresIn(Integer expiresIn) { + this.expiresIn = expiresIn; + } + + public String getAccessToken() { + return accessToken; + } + + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + public String getScope() { + return scope; + } + + public void setScope(String scope) { + this.scope = scope; + } + + public Instant getIssuedAt() { + return issuedAt; + } + + public boolean hasExpired() { + Duration duration = Duration.between(this.getIssuedAt(), Instant.now()); + return duration.getSeconds() >= this.getExpiresIn(); + } + + // for testing purposes + void expireNow() { + this.setExpiresIn(Integer.MIN_VALUE); + } + + @Override + public String toString() { + return "OAuth2AccessToken [tokenType=" + tokenType + + ", issuedAt=" + issuedAt + + ", expiresIn=" + expiresIn + + ", accessToken=xxxxx" + + ", scope=" + scope + "]"; + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java new file mode 100644 index 00000000000..a514c2ce19e --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.commons.lang.Assert; +import com.okta.sdk.authc.credentials.ClientCredentials; + +import java.io.IOException; +import java.security.InvalidKeyException; + +/** + * This implementation represents client credentials specific to OAuth2 Authentication scheme. + * + * @since 1.6.0 + */ +public class OAuth2ClientCredentials implements ClientCredentials { + + private OAuth2AccessToken oAuth2AccessToken; + private AccessTokenRetrieverService accessTokenRetrieverService; + + public OAuth2ClientCredentials(AccessTokenRetrieverService accessTokenRetrieverService) { + Assert.notNull(accessTokenRetrieverService, "accessTokenRetrieverService must not be null."); + this.accessTokenRetrieverService = accessTokenRetrieverService; + this.oAuth2AccessToken = eagerFetchOAuth2AccessToken(); + } + + private OAuth2AccessToken eagerFetchOAuth2AccessToken() { + OAuth2AccessToken accessToken; + + try { + accessToken = accessTokenRetrieverService.getOAuth2AccessToken(); + } catch (IOException | InvalidKeyException e) { + throw new OAuth2TokenRetrieverException("Failed to fetch OAuth2 access token eagerly", e); + } + + if (accessToken == null) { + throw new OAuth2TokenRetrieverException("Failed to fetch OAuth2 access token eagerly"); + } + + return accessToken; + } + + public OAuth2AccessToken getCredentials() { + return oAuth2AccessToken; + } + + public void setCredentials(OAuth2AccessToken oAuth2AccessToken) { + this.oAuth2AccessToken = oAuth2AccessToken; + } + + public AccessTokenRetrieverService getAccessTokenRetrieverService() { + return accessTokenRetrieverService; + } + + @Override + public String toString() { + // never ever print the secret + return ""; + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java new file mode 100644 index 00000000000..f24358ec59e --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java @@ -0,0 +1,28 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.commons.http.HttpException; + +/** + * @since 1.6.0 + */ +public class OAuth2HttpException extends HttpException { + + public OAuth2HttpException(String errMsg, Throwable cause, boolean retryable) { + super(errMsg, cause, retryable); + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java new file mode 100644 index 00000000000..d602b582888 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java @@ -0,0 +1,30 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +/** + * @since 1.6.0 + */ +public class OAuth2TokenRetrieverException extends RuntimeException { + + public OAuth2TokenRetrieverException(String s) { + super(s); + } + + public OAuth2TokenRetrieverException(String s, Throwable cause) { + super(s, cause); + } +} diff --git a/impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy b/impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy new file mode 100644 index 00000000000..8b682fd2627 --- /dev/null +++ b/impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy @@ -0,0 +1,438 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2 + +import com.nimbusds.jose.jwk.RSAKey +import com.okta.commons.http.config.BaseUrlResolver +import com.okta.sdk.client.AuthorizationMode +import com.okta.sdk.client.Clients +import com.okta.sdk.error.Error +import com.okta.sdk.error.ErrorCause +import com.okta.sdk.error.ResourceException +import com.okta.sdk.impl.Util +import com.okta.sdk.impl.api.DefaultClientCredentialsResolver +import com.okta.sdk.impl.client.DefaultClientBuilder +import com.okta.sdk.impl.config.ClientConfiguration +import io.jsonwebtoken.Claims +import io.jsonwebtoken.Header +import io.jsonwebtoken.Jwts +import org.bouncycastle.openssl.PEMException +import org.hamcrest.MatcherAssert +import org.mockito.ArgumentMatchers +import org.openapitools.client.ApiClient +import org.springframework.http.HttpHeaders +import org.springframework.http.MediaType +import org.springframework.util.MultiValueMap +import org.testng.annotations.Test + +import java.security.KeyPair +import java.security.KeyPairGenerator +import java.security.PrivateKey + +import static org.hamcrest.MatcherAssert.assertThat +import static org.hamcrest.Matchers.is +import static org.hamcrest.Matchers.notNullValue +import static org.mockito.ArgumentMatchers.* +import static org.mockito.Mockito.* +import static org.testng.Assert.assertEquals + +/** + * Test for {@link AccessTokenRetrieverServiceImpl} class + * + * @since 1.6.0 + */ +class AccessTokenRetrieverServiceImplTest { + + private static final String PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n" + + "MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDH0Y47a8w/Tgiv" + + "V4mytjBSR/5HIu+P58/v3g6gbYvxC/NWPzPZ3hjTeRskJpB1AfNwm55rAhjSD99d" + + "4ZiHWMEFjEeSKIaEtMxDPU1pg23R/e+sATVevEfx1G1+IoaSu6SKLnHN7iNtvWlK" + + "reR5pNUVHKcRotg/auiNUd8P9Wok8FQhFGxbZEdYhjvICLfHLrZKQKOR1AwqPscX" + + "+FnLF2F+9X0QXRAX3XW/D4S9sfS9JN+J6mhhOQpy78p5VDxGPZim2bk/WNhB3uQ+" + + "4UE7xdMYIMBPxP3kd2/vhBLG+AhqGvh1XhROvCQ2mtuRq+vTFCfd+tObx1b5a7eU" + + "UwVWM8AxAgMBAAECggEBAJ1vapU+1ep63TTpz8BS87egqaP6zq2fg6IGX5ffOAdv" + + "1wX5Pi1GZGEaZlwRVngaVWg/9I1zVYMMpn0dpkPdlhd881chPvuIR/gicL/Voc12" + + "OkRXn2lJB5ZuPObI5SbvWTDWbyxFmPx55F/GquF9EbZUoP2wRJmS7i+Kdinovvzi" + + "Rwgevpdo/IA3uosX5NzIazIhCeb4v1v2tpvNH63pfAzKsEHVF0bVQ7yRtrDE5bbp" + + "NbyNR5em30G2CXqNQIdKMQAL3b4LfCGkXrJABVszjTXQO117PhCnifNmLGzlVhBD" + + "qC65Luh5GJ21qvj2InWIYdLft0DvzUVH29Bb9uB4H5ECgYEA/+FIE5PDRx+PP9Zq" + + "kKtXaTtBstZAUbx0vwIs1IruB9Xio/lskZ8woAXDqWYR4VOkqR6JeOPImHFN8keK" + + "Vvd04J/2nDEHa4SPy23Ww0YcmtvlpRwsWZ1JktO9PFK1YadyUEYWVjWToSGmbZD3" + + "aPOSKf+uFDuOUClyGUjZiWBMWRMCgYEAx+mLOk76eraiLC3QIJ7g9XjBDBmBBBQA" + + "3yk7F86zrrRxb/FA7G9zc19GkMjnGT5QYG7Qdw0LRbh4AT3zkMPVbd4Cy3YRhKFi" + + "XwF5loOv4YHKlB+Ny9yKC7Jz9tzhccOAxjyjwDtY2tw/DQP4xdgtDuMccr7DLdrB" + + "8mrZNn1vTisCgYAxdd50yk8o5FTQRiX7KOOQl7+vTfLI2eDHOyhnPSOdqB5TC9eM" + + "nnTLudGEYRJ7t6tQdXKlR4Jy1RP4DRQUk2ioMsN8lY2Vnt4cuHKW9Gp7FJ5jN/rq" + + "p5idJQijLGmbIr7Z/XI738dVkieVbjwksVBDhgSkLI7pt9kyQf6qq06WuQKBgQC6" + + "E2b1ghfhauduacIk6t2HfrtpkL+m1RuunEkVst9KyUghIxUEPgTfKZqcH3QD6h2U" + + "dPDzLyAD6F1DArAYWj/pwNEnIqHRqwnOVqge8joek9nEn84zJ/cSRitsZ1IsuwW8" + + "/yqIPnVJWeISMlU3iiz+g2SyZV906f7Grq+56W1V+wKBgQDrgV2VHJyIHuHS6t+A" + + "BV89ditsFt4n2h2SPzX9xI5uwUclwPy01bCaMccKvzPwhmJWMRDzIeKgLy3aJWkl" + + "8zZeebOsKqNB3Nlm8wNrYbJJvpTEyt6kpZi+YJ/S4mQ/0CzIpmq6014aKm16NJxv" + + "AHYKHuDUvhiDFwadf8Q7kSK5KA==\n" + + "-----END PRIVATE KEY-----" + + private static final String RSA_PRIVATE_KEY = "-----BEGIN RSA PRIVATE KEY-----\n" + + "MIIEpQIBAAKCAQEAx9GOO2vMP04Ir1eJsrYwUkf+RyLvj+fP794OoG2L8QvzVj8z" + + "2d4Y03kbJCaQdQHzcJueawIY0g/fXeGYh1jBBYxHkiiGhLTMQz1NaYNt0f3vrAE1" + + "XrxH8dRtfiKGkrukii5xze4jbb1pSq3keaTVFRynEaLYP2rojVHfD/VqJPBUIRRs" + + "W2RHWIY7yAi3xy62SkCjkdQMKj7HF/hZyxdhfvV9EF0QF911vw+EvbH0vSTfiepo" + + "YTkKcu/KeVQ8Rj2Yptm5P1jYQd7kPuFBO8XTGCDAT8T95Hdv74QSxvgIahr4dV4U" + + "TrwkNprbkavr0xQn3frTm8dW+Wu3lFMFVjPAMQIDAQABAoIBAQCdb2qVPtXqet00" + + "6c/AUvO3oKmj+s6tn4OiBl+X3zgHb9cF+T4tRmRhGmZcEVZ4GlVoP/SNc1WDDKZ9" + + "HaZD3ZYXfPNXIT77iEf4InC/1aHNdjpEV59pSQeWbjzmyOUm71kw1m8sRZj8eeRf" + + "xqrhfRG2VKD9sESZku4vinYp6L784kcIHr6XaPyAN7qLF+TcyGsyIQnm+L9b9rab" + + "zR+t6XwMyrBB1RdG1UO8kbawxOW26TW8jUeXpt9Btgl6jUCHSjEAC92+C3whpF6y" + + "QAVbM4010Dtdez4Qp4nzZixs5VYQQ6guuS7oeRidtar49iJ1iGHS37dA781FR9vQ" + + "W/bgeB+RAoGBAP/hSBOTw0cfjz/WapCrV2k7QbLWQFG8dL8CLNSK7gfV4qP5bJGf" + + "MKAFw6lmEeFTpKkeiXjjyJhxTfJHilb3dOCf9pwxB2uEj8tt1sNGHJrb5aUcLFmd" + + "SZLTvTxStWGnclBGFlY1k6Ehpm2Q92jzkin/rhQ7jlApchlI2YlgTFkTAoGBAMfp" + + "izpO+nq2oiwt0CCe4PV4wQwZgQQUAN8pOxfOs660cW/xQOxvc3NfRpDI5xk+UGBu" + + "0HcNC0W4eAE985DD1W3eAst2EYShYl8BeZaDr+GBypQfjcvciguyc/bc4XHDgMY8" + + "o8A7WNrcPw0D+MXYLQ7jHHK+wy3awfJq2TZ9b04rAoGAMXXedMpPKORU0EYl+yjj" + + "kJe/r03yyNngxzsoZz0jnageUwvXjJ50y7nRhGESe7erUHVypUeCctUT+A0UFJNo" + + "qDLDfJWNlZ7eHLhylvRqexSeYzf66qeYnSUIoyxpmyK+2f1yO9/HVZInlW48JLFQ" + + "Q4YEpCyO6bfZMkH+qqtOlrkCgYEAuhNm9YIX4WrnbmnCJOrdh367aZC/ptUbrpxJ" + + "FbLfSslIISMVBD4E3ymanB90A+odlHTw8y8gA+hdQwKwGFo/6cDRJyKh0asJzlao" + + "HvI6HpPZxJ/OMyf3EkYrbGdSLLsFvP8qiD51SVniEjJVN4os/oNksmVfdOn+xq6v" + + "ueltVfsCgYEA64FdlRyciB7h0urfgAVfPXYrbBbeJ9odkj81/cSObsFHJcD8tNWw" + + "mjHHCr8z8IZiVjEQ8yHioC8t2iVpJfM2XnmzrCqjQdzZZvMDa2GySb6UxMrepKWY" + + "vmCf0uJkP9AsyKZqutNeGiptejScbwB2Ch7g1L4YgxcGnX/EO5EiuSg=\n" + + "-----END RSA PRIVATE KEY-----" + + @Test + void testInstantiationWithNullClientConfig() { + Util.expect(IllegalArgumentException) { + new AccessTokenRetrieverServiceImpl(null, null) + } + } + + @Test + void testGetPrivateKeyFromPem() { + + def clientConfiguration = mock(ClientConfiguration) + def apiClient = mock(ApiClient) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + // Now test the pem -> private key conversion function of getPrivateKeyFromPem method + Reader pemFileReader = new FileReader(privateKeyPemFile) + + PrivateKey resultPrivateKey = getAccessTokenRetrieverServiceInstance(clientConfiguration, apiClient).getPrivateKeyFromPEM(pemFileReader) + + privateKeyPemFile.deleteOnExit() + + assertThat(resultPrivateKey, notNullValue()) + MatcherAssert.assertThat(resultPrivateKey.getAlgorithm(), is("RSA")) + MatcherAssert.assertThat(resultPrivateKey.getFormat(), is("PKCS#8")) + } + + @Test + void testParsePrivateKey() { + + def apiClient = mock(ApiClient) + def clientConfiguration = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + Reader reader = new BufferedReader(new FileReader(privateKeyPemFile)) + + PrivateKey parsedPrivateKey = getAccessTokenRetrieverServiceInstance(clientConfiguration, apiClient).parsePrivateKey(reader) + + privateKeyPemFile.deleteOnExit() + + assertThat(parsedPrivateKey, notNullValue()) + MatcherAssert.assertThat(parsedPrivateKey.getAlgorithm(), is("RSA")) + MatcherAssert.assertThat(parsedPrivateKey.getFormat(), is("PKCS#8")) + } + + @Test + void testCreateSignedJWT() { + + def clientConfig = mock(ClientConfiguration) + def apiClient = mock(ApiClient) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(privateKeyPemFile.path) + when(clientConfig.getKid()).thenReturn("kid-value") + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + String signedJwt = getAccessTokenRetrieverServiceInstance(clientConfig, apiClient).createSignedJWT() + + privateKeyPemFile.deleteOnExit() + + assertThat(signedJwt, notNullValue()) + + // decode the signed jwt and verify + Claims claims = Jwts.parserBuilder() + .setSigningKey(generatedPrivateKey) + .build() + .parseClaimsJws(signedJwt).getBody() + + assertThat(claims, notNullValue()) + + assertEquals(claims.get("aud"), clientConfig.getBaseUrl() + "/oauth2/v1/token") + assertThat(claims.get("iat"), notNullValue()) + assertThat(claims.get("exp"), notNullValue()) + assertEquals(Integer.valueOf(claims.get("exp") as String) - Integer.valueOf(claims.get("iat") as String), 3000, + "token expiry time is not 50 minutes") + assertThat(claims.get("iss"), notNullValue()) + assertEquals(claims.get("iss"), clientConfig.getClientId(), "iss must be equal to client id") + assertThat(claims.get("sub"), notNullValue()) + assertEquals(claims.get("sub"), clientConfig.getClientId(), "sub must be equal to client id") + assertThat(claims.get("jti"), notNullValue()) + + Header header = Jwts.parser() + .setSigningKey(generatedPrivateKey) + .parseClaimsJws(signedJwt) + .getHeader() + + assertThat(header.get("kid"), notNullValue()) + assertThat(header.get("kid"), is("kid-value")) + } + + @Test + void testCreateSignedJWTUsingPrivateKeyFromString() { + + def apiClient = mock(ApiClient) + def clientConfig = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(createPemFileContent(generatedPrivateKey)) + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + String signedJwt = getAccessTokenRetrieverServiceInstance(clientConfig, apiClient).createSignedJWT() + + assertThat(signedJwt, notNullValue()) + } + + @Test(expectedExceptions = OAuth2TokenRetrieverException.class) + void testGetOAuth2TokenRetrieverRuntimeException() { + + def tokenClient = mock(ApiClient) + def clientConfig = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(privateKeyPemFile.path) + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + def accessTokenRetrieverService = new AccessTokenRetrieverServiceImpl(clientConfig, tokenClient) + accessTokenRetrieverService.getOAuth2AccessToken() + + verify(tokenClient, times(1)) + } + + @Test(expectedExceptions = OAuth2HttpException.class) + void testGetOAuth2ResourceException() { + + def apiClient = mock(ApiClient) + def clientConfig = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(privateKeyPemFile.path) + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + Error defaultError = new Error() { + @Override + int getStatus() { + return 401 + } + + @Override + String getCode() { + return null + } + + @Override + String getMessage() { + return "Unauthorized" + } + + @Override + String getId() { + return null + } + + @Override + List getCauses() { + return null + } + + @Override + Map> getHeaders() { + return null + } + } + + ResourceException resourceException = new ResourceException(defaultError) + + when(apiClient.invokeAPI(anyString(), + anyObject(), + ArgumentMatchers.any() as Map, + ArgumentMatchers.any() as MultiValueMap, + any(), + ArgumentMatchers.any() as HttpHeaders, + ArgumentMatchers.any() as MultiValueMap, + ArgumentMatchers.any() as MultiValueMap, + ArgumentMatchers.any() as List, + ArgumentMatchers.any() as MediaType, + ArgumentMatchers.any() as String[], + any())).thenThrow(resourceException) + + def accessTokenRetrieverService = new AccessTokenRetrieverServiceImpl(clientConfig, apiClient) + accessTokenRetrieverService.getOAuth2AccessToken() + + verify(apiClient, times(1)).invokeAPI() + } + + // helper methods + + PrivateKey generatePrivateKey(String algorithm, int keySize) { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(algorithm) + keyPairGenerator.initialize(keySize) + KeyPair keyPair = keyPairGenerator.generateKeyPair() + PrivateKey privateKey = keyPair.getPrivate() + return privateKey + } + + String createPemFileContent(PrivateKey privateKey) { + String encodedString = "-----BEGIN PRIVATE KEY-----\n" + encodedString = encodedString + Base64.getEncoder().encodeToString(privateKey.getEncoded()) + "\n" + encodedString = encodedString + "-----END PRIVATE KEY-----\n" + return encodedString + + } + File writePrivateKeyToPemFile(PrivateKey privateKey, String fileNamePrefix) { + File privateKeyPemFile = File.createTempFile(fileNamePrefix,".pem") + privateKeyPemFile.write(createPemFileContent(privateKey)) + return privateKeyPemFile + } + + AccessTokenRetrieverServiceImpl getAccessTokenRetrieverServiceInstance(ClientConfiguration clientConfiguration, ApiClient apiClient) { + if (clientConfiguration == null) { + ClientConfiguration cc = new ClientConfiguration() + cc.setBaseUrlResolver(new BaseUrlResolver() { + @Override + String getBaseUrl() { + return "https://sample.okta.com" + } + }) + cc.setClientCredentialsResolver(new DefaultClientCredentialsResolver({ -> Optional.empty() })) + return new AccessTokenRetrieverServiceImpl(cc, apiClient) + } + + return new AccessTokenRetrieverServiceImpl(clientConfiguration, apiClient) + } + + @Test + void testConvertPemKeyToRsaPrivateKey() { + ApiClient apiClient = mock(ApiClient) + + DefaultClientBuilder oktaClient = (DefaultClientBuilder) Clients.builder() + .setOrgUrl("https://sample.okta.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setPrivateKey(RSAKey.parseFromPEMEncodedObjects(PRIVATE_KEY).toRSAKey().toPrivateKey()) + + ClientConfiguration clientConfiguration = oktaClient.getClientConfiguration() + + assertEquals(clientConfiguration.getPrivateKey(), RSA_PRIVATE_KEY) + + String signedJwt = getAccessTokenRetrieverServiceInstance(clientConfiguration, apiClient).createSignedJWT() + assertThat(signedJwt, notNullValue()) + } + + @Test + void testParseRsaPrivateKey() { + ApiClient apiClient = mock(ApiClient) + + DefaultClientBuilder oktaClient = (DefaultClientBuilder) Clients.builder() + .setOrgUrl("https://sample.okta.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setPrivateKey(RSAKey.parseFromPEMEncodedObjects(RSA_PRIVATE_KEY).toRSAKey().toPrivateKey()) + + String signedJwt = getAccessTokenRetrieverServiceInstance(oktaClient.getClientConfiguration(), apiClient).createSignedJWT() + assertThat(signedJwt, notNullValue()) + } + + @Test(expectedExceptions = PEMException.class) + void testParsePemKeyAsRsaPrivateKey() { + ClientConfiguration clientConfigMock = mock(ClientConfiguration) + ApiClient apiClient = mock(ApiClient) + + BaseUrlResolver baseUrlResolver = { -> "https://sample.okta.com" } + when(clientConfigMock.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfigMock.getPrivateKey()).thenReturn(PRIVATE_KEY.replaceAll(" PRIVATE ", " RSA PRIVATE ")) + + String signedJwt = getAccessTokenRetrieverServiceInstance(clientConfigMock, apiClient).createSignedJWT() + assertThat(signedJwt, notNullValue()) + } +} \ No newline at end of file