From 6aba34cac5eed62e491b1f274075d1ecf96eb973 Mon Sep 17 00:00:00 2001 From: Arvind Krishnakumar <61501885+arvindkrishnakumar-okta@users.noreply.github.com> Date: Wed, 29 Apr 2020 20:56:07 -0700 Subject: [PATCH] Add OAuth2 Support (#354) * Added OAuth 2.0 Support --- THIRD-PARTY-NOTICES | 5 +- .../okta/sdk/client/AuthenticationScheme.java | 5 +- .../okta/sdk/client/AuthorizationMode.java | 66 +++++ .../com/okta/sdk/client/ClientBuilder.java | 53 ++++ .../main/java/quickstart/ReadmeSnippets.java | 5 +- impl/pom.xml | 22 ++ .../sdk/impl/client/DefaultClientBuilder.java | 112 +++++++- .../sdk/impl/config/ClientConfiguration.java | 64 ++++- .../config/DefaultEnvVarNameConverter.java | 6 +- .../authc/OAuth2RequestAuthenticator.java | 75 ++++++ .../oauth2/AccessTokenRetrieverService.java | 36 +++ .../AccessTokenRetrieverServiceImpl.java | 248 +++++++++++++++++ .../sdk/impl/oauth2/OAuth2AccessToken.java | 102 +++++++ .../impl/oauth2/OAuth2ClientCredentials.java | 73 +++++ .../sdk/impl/oauth2/OAuth2HttpException.java | 29 ++ .../sdk/impl/oauth2/OAuth2TokenClient.java | 30 +++ .../oauth2/OAuth2TokenRetrieverException.java | 30 +++ .../client/DefaultClientBuilderTest.groovy | 196 +++++++++++++- ...entBuilderTestCustomCredentialsTest.groovy | 10 + ...RequestAuthenticatorConcurrencyTest.groovy | 87 ++++++ .../OAuth2RequestAuthenticatorTest.groovy | 223 +++++++++++++++ ...AccessTokenRetrieverServiceImplTest.groovy | 255 ++++++++++++++++++ .../okta/sdk/tests/it/ApplicationsIT.groovy | 3 +- .../sdk/tests/it/util/ClientProvider.groovy | 1 - pom.xml | 43 ++- src/findbugs/findbugs-exclude.xml | 2 +- src/license/NOTICE.template | 2 +- src/license/mapping.xml | 7 +- 28 files changed, 1744 insertions(+), 46 deletions(-) create mode 100644 api/src/main/java/com/okta/sdk/client/AuthorizationMode.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/http/authc/OAuth2RequestAuthenticator.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenClient.java create mode 100644 impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java create mode 100644 impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorConcurrencyTest.groovy create mode 100644 impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorTest.groovy create mode 100644 impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index 95ee51faba6..9729fe2228c 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -1,4 +1,4 @@ -Copyright 2017 Okta +Copyright 2017-Present Okta, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ This project includes: AutoService under Apache 2.0 AutoService Processor under Apache 2.0 Bean Validation API under The Apache Software License, Version 2.0 + Bouncy Castle PKIX, CMS, EAC, TSP, PKCS, OCSP, CMP, and CRMF APIs under Bouncy Castle Licence + Bouncy Castle Provider under Bouncy Castle Licence commonmark-java core under BSD 2-Clause License Commons CLI under The Apache Software License, Version 2.0 Commons IO under The Apache Software License, Version 2.0 @@ -33,6 +35,7 @@ This project includes: JavaMail API jar under CDDL or GPLv2+CE javax.annotation API under CDDL + GPLv2 with classpath exception JCL 1.2 implemented over SLF4J under Apache License, Version 2.0 + JJWT :: API under Apache License, Version 2.0 jmustache under The (New) BSD License Joda-Time under Apache 2 JOpt Simple under The MIT License diff --git a/api/src/main/java/com/okta/sdk/client/AuthenticationScheme.java b/api/src/main/java/com/okta/sdk/client/AuthenticationScheme.java index 35d6c14f630..77fe8b6bba3 100644 --- a/api/src/main/java/com/okta/sdk/client/AuthenticationScheme.java +++ b/api/src/main/java/com/okta/sdk/client/AuthenticationScheme.java @@ -26,14 +26,15 @@ * The Authentication Scheme setting is helpful in cases where the code is run in a platform where the header information for * outgoing HTTP requests is modified and thus causing communication issues. *

- * The SSWS (Okta session bearer token) should be used for the management SDK, {code NONE} should be used for non - * authenticated requests. + * One of SSWS (Okta session bearer token) (or) OAUTH2 authentication schemes should be used for the management SDK, {@code NONE} + * should be used for unauthenticated requests. * * @since 0.5.0 */ public enum AuthenticationScheme { SSWS("com.okta.sdk.impl.http.authc.SswsAuthenticator"), //SSWS Authentication + OAUTH2_PRIVATE_KEY("com.okta.sdk.impl.http.authc.OAuth2RequestAuthenticator"), //OAuth2 NONE(DisabledAuthenticator.class); private final String requestAuthenticatorClassName; diff --git a/api/src/main/java/com/okta/sdk/client/AuthorizationMode.java b/api/src/main/java/com/okta/sdk/client/AuthorizationMode.java new file mode 100644 index 00000000000..f4914778e53 --- /dev/null +++ b/api/src/main/java/com/okta/sdk/client/AuthorizationMode.java @@ -0,0 +1,66 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.client; + +import java.util.HashMap; +import java.util.Map; + +/** + * Enumeration that defines the mapping between available Authentication schemes and Authorization modes. + */ +public enum AuthorizationMode { + + SSWS("SSWS", AuthenticationScheme.SSWS), // SSWS + PRIVATE_KEY("PrivateKey", AuthenticationScheme.OAUTH2_PRIVATE_KEY), // OAuth2 + NONE("NONE", AuthenticationScheme.NONE); // None + + private final String label; + private final AuthenticationScheme authenticationScheme; + + private static final Map lookup = new HashMap<>(); + + static { + for (AuthorizationMode authorizationMode : AuthorizationMode.values()) { + lookup.put(authorizationMode.getAuthenticationScheme(), authorizationMode); + } + } + + AuthorizationMode(String label, AuthenticationScheme authenticationScheme) { + this.label = label; + this.authenticationScheme = authenticationScheme; + } + + public String getLabel() { + return this.label; + } + + public AuthenticationScheme getAuthenticationScheme() { + return this.authenticationScheme; + } + + public static AuthorizationMode get(AuthenticationScheme authenticationScheme) { + return lookup.get(authenticationScheme); + } + + public static AuthorizationMode getAuthorizationMode(String label) { + for (AuthorizationMode authorizationMode : values()) { + if (authorizationMode.getLabel().equals(label)) { + return authorizationMode; + } + } + throw new IllegalArgumentException(); + } +} diff --git a/api/src/main/java/com/okta/sdk/client/ClientBuilder.java b/api/src/main/java/com/okta/sdk/client/ClientBuilder.java index 58d254c50cd..c1f2e288789 100644 --- a/api/src/main/java/com/okta/sdk/client/ClientBuilder.java +++ b/api/src/main/java/com/okta/sdk/client/ClientBuilder.java @@ -19,6 +19,8 @@ import com.okta.sdk.authc.credentials.ClientCredentials; import com.okta.sdk.cache.CacheManager; +import java.util.Set; + /** * A Builder design pattern used to * construct {@link com.okta.sdk.client.Client} instances. @@ -206,6 +208,10 @@ public interface ClientBuilder { String DEFAULT_CLIENT_PROXY_HOST_PROPERTY_NAME = "okta.client.proxy.host"; String DEFAULT_CLIENT_PROXY_USERNAME_PROPERTY_NAME = "okta.client.proxy.username"; String DEFAULT_CLIENT_PROXY_PASSWORD_PROPERTY_NAME = "okta.client.proxy.password"; + String DEFAULT_CLIENT_AUTHORIZATION_MODE_PROPERTY_NAME = "okta.client.authorizationMode"; + String DEFAULT_CLIENT_ID_PROPERTY_NAME = "okta.client.clientId"; + String DEFAULT_CLIENT_SCOPES_PROPERTY_NAME = "okta.client.scopes"; + String DEFAULT_CLIENT_PRIVATE_KEY_PROPERTY_NAME = "okta.client.privateKey"; String DEFAULT_CLIENT_REQUEST_TIMEOUT_PROPERTY_NAME = "okta.client.requestTimeout"; String DEFAULT_CLIENT_RETRY_MAX_ATTEMPTS_PROPERTY_NAME = "okta.client.rateLimit.maxRetries"; String DEFAULT_CLIENT_TESTING_DISABLE_HTTPS_CHECK_PROPERTY_NAME = "okta.testing.disableHttpsCheck"; @@ -306,9 +312,56 @@ public interface ClientBuilder { * * @param authenticationScheme the type of authentication to be used for communication with the Okta API server. * @return the ClientBuilder instance for method chaining + * + * @deprecated since 1.6.0 use {@link #setAuthorizationMode(AuthorizationMode)} to indicate the authentication scheme. */ + @Deprecated ClientBuilder setAuthenticationScheme(AuthenticationScheme authenticationScheme); + /** + * Allows specifying an authorization mode. + * + * @param authorizationMode mode of authorization for requests to the Okta API server. + * @return the ClientBuilder instance for method chaining. + * + * @since 1.6.0 + */ + ClientBuilder setAuthorizationMode(AuthorizationMode authorizationMode); + + /** + * Allows specifying a list of scopes directly instead of relying on the + * default location + override/fallback behavior defined in the {@link ClientBuilder documentation above}. + * + * @param scopes set of scopes for which the client requests access. + * @return the ClientBuilder instance for method chaining. + * + * @since 1.6.0 + */ + ClientBuilder setScopes(Set scopes); + + /** + * Allows specifying the private key (PEM file) path (for private key jwt authentication) directly instead + * of relying on the default location + override/fallback behavior defined + * in the {@link ClientBuilder documentation above}. + * + * @param privateKey the fully qualified string path to the private key (PEM file). + * @return the ClientBuilder instance for method chaining. + * + * @since 1.6.0 + */ + ClientBuilder setPrivateKey(String privateKey); + + /** + * Allows specifying the client ID instead of relying on the default location + override/fallback behavior defined + * in the {@link ClientBuilder documentation above}. + * + * @param clientId string representing the client ID. + * @return the ClientBuilder instance for method chaining. + * + * @since 1.6.0 + */ + ClientBuilder setClientId(String clientId); + /** * Sets both the timeout until a connection is established and the socket timeout (i.e. a maximum period of inactivity * between two consecutive data packets). A timeout value of zero is interpreted as an infinite timeout. diff --git a/examples/quickstart/src/main/java/quickstart/ReadmeSnippets.java b/examples/quickstart/src/main/java/quickstart/ReadmeSnippets.java index 2cd87f3952f..89304ba7a87 100644 --- a/examples/quickstart/src/main/java/quickstart/ReadmeSnippets.java +++ b/examples/quickstart/src/main/java/quickstart/ReadmeSnippets.java @@ -46,7 +46,7 @@ /** * Example snippets used for this projects README.md. *

- * Manually run {@code mvn okta-code-snippet:snip} after chaging this file to update the README.md. + * Manually run {@code mvn okta-code-snippet:snip} after changing this file to update the README.md. */ @SuppressWarnings({"unused"}) public class ReadmeSnippets { @@ -219,4 +219,5 @@ private void disableCaching() { .setCacheManager(Caches.newDisabledCacheManager()) .build(); } -} \ No newline at end of file +} + diff --git a/impl/pom.xml b/impl/pom.xml index 847840aaced..c36b4906bfd 100644 --- a/impl/pom.xml +++ b/impl/pom.xml @@ -56,6 +56,28 @@ org.yaml snakeyaml + + org.bouncycastle + bcprov-jdk15on + + + org.bouncycastle + bcpkix-jdk15on + + + io.jsonwebtoken + jjwt-api + + + io.jsonwebtoken + jjwt-impl + runtime + + + io.jsonwebtoken + jjwt-jackson + runtime + javax.annotation javax.annotation-api diff --git a/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java b/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java index d6f6e2b0c12..71a48041ea6 100644 --- a/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java +++ b/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java @@ -21,17 +21,18 @@ import com.okta.commons.lang.Assert; import com.okta.commons.lang.Classes; import com.okta.commons.lang.Strings; +import com.okta.sdk.authc.credentials.ClientCredentials; import com.okta.sdk.cache.CacheConfigurationBuilder; import com.okta.sdk.cache.CacheManager; import com.okta.sdk.cache.CacheManagerBuilder; import com.okta.sdk.cache.Caches; import com.okta.sdk.client.AuthenticationScheme; +import com.okta.sdk.client.AuthorizationMode; import com.okta.sdk.client.Client; import com.okta.sdk.client.ClientBuilder; import com.okta.sdk.client.Proxy; import com.okta.sdk.impl.api.ClientCredentialsResolver; import com.okta.sdk.impl.api.DefaultClientCredentialsResolver; -import com.okta.sdk.authc.credentials.ClientCredentials; import com.okta.sdk.impl.config.ClientConfiguration; import com.okta.sdk.impl.config.EnvironmentVariablesPropertiesSource; import com.okta.sdk.impl.config.OptionalPropertiesSource; @@ -44,15 +45,25 @@ import com.okta.sdk.impl.io.DefaultResourceFactory; import com.okta.sdk.impl.io.Resource; import com.okta.sdk.impl.io.ResourceFactory; +import com.okta.sdk.impl.oauth2.AccessTokenRetrieverService; +import com.okta.sdk.impl.oauth2.AccessTokenRetrieverServiceImpl; +import com.okta.sdk.impl.oauth2.OAuth2ClientCredentials; import com.okta.sdk.impl.util.DefaultBaseUrlResolver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.nio.file.Files; +import java.nio.file.LinkOption; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; /** @@ -92,13 +103,14 @@ public class DefaultClientBuilder implements ClientBuilder { SYSPROPS_TOKEN }; - private CacheManager cacheManager; private ClientCredentials clientCredentials; private boolean allowNonHttpsForTesting = false; private ClientConfiguration clientConfig = new ClientConfiguration(); + private AccessTokenRetrieverService accessTokenRetrieverService; + public DefaultClientBuilder() { this(new DefaultResourceFactory()); } @@ -210,6 +222,23 @@ else if (SYSPROPS_TOKEN.equalsIgnoreCase(location)) { clientConfig.setProxyPassword(props.get(DEFAULT_CLIENT_PROXY_PASSWORD_PROPERTY_NAME)); } + if (Strings.hasText(props.get(DEFAULT_CLIENT_AUTHORIZATION_MODE_PROPERTY_NAME))) { + clientConfig.setAuthorizationMode(AuthorizationMode.getAuthorizationMode(props.get(DEFAULT_CLIENT_AUTHORIZATION_MODE_PROPERTY_NAME))); + } + + if (Strings.hasText(props.get(DEFAULT_CLIENT_ID_PROPERTY_NAME))) { + clientConfig.setClientId(props.get(DEFAULT_CLIENT_ID_PROPERTY_NAME)); + } + + if (Strings.hasText(props.get(DEFAULT_CLIENT_SCOPES_PROPERTY_NAME))) { + Set scopes = new HashSet<>(Arrays.asList(props.get(DEFAULT_CLIENT_SCOPES_PROPERTY_NAME).split(" "))); + clientConfig.setScopes(scopes); + } + + if (Strings.hasText(props.get(DEFAULT_CLIENT_PRIVATE_KEY_PROPERTY_NAME))) { + clientConfig.setPrivateKey(props.get(DEFAULT_CLIENT_PRIVATE_KEY_PROPERTY_NAME)); + } + if (Strings.hasText(props.get(DEFAULT_CLIENT_REQUEST_TIMEOUT_PROPERTY_NAME))) { clientConfig.setRetryMaxElapsed(Integer.parseInt(props.get(DEFAULT_CLIENT_REQUEST_TIMEOUT_PROPERTY_NAME))); } @@ -239,8 +268,7 @@ public ClientBuilder setCacheManager(CacheManager cacheManager) { @Override public ClientBuilder setAuthenticationScheme(AuthenticationScheme authenticationScheme) { - this.clientConfig.setAuthenticationScheme(authenticationScheme); - return this; + return setAuthorizationMode(AuthorizationMode.get(authenticationScheme)); } @Override @@ -307,21 +335,46 @@ public Client build() { this.cacheManager = cacheManagerBuilder.build(); } - if (this.clientConfig.getClientCredentialsResolver() == null && this.clientCredentials != null) { - this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientCredentials)); - } - else if (this.clientConfig.getClientCredentialsResolver() == null) { - this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(clientConfig)); - } - if (this.clientConfig.getBaseUrlResolver() == null) { ConfigurationValidator.assertOrgUrl(this.clientConfig.getBaseUrl(), allowNonHttpsForTesting); this.clientConfig.setBaseUrlResolver(new DefaultBaseUrlResolver(this.clientConfig.getBaseUrl())); } + if (!isOAuth2Flow()) { + if (this.clientConfig.getClientCredentialsResolver() == null && this.clientCredentials != null) { + this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientCredentials)); + } else if (this.clientConfig.getClientCredentialsResolver() == null) { + this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(this.clientConfig)); + } + } else { + this.clientConfig.setAuthenticationScheme(AuthenticationScheme.OAUTH2_PRIVATE_KEY); + + validateOAuth2ClientConfig(this.clientConfig); + + accessTokenRetrieverService = new AccessTokenRetrieverServiceImpl(clientConfig); + + OAuth2ClientCredentials oAuth2ClientCredentials = + new OAuth2ClientCredentials(accessTokenRetrieverService); + + this.clientConfig.setClientCredentialsResolver(new DefaultClientCredentialsResolver(oAuth2ClientCredentials)); + } + return new DefaultClient(clientConfig, cacheManager); } + /** + * @since 1.6.0 + */ + private void validateOAuth2ClientConfig(ClientConfiguration clientConfiguration) { + Assert.notNull(clientConfiguration.getClientId(), "clientId cannot be null"); + Assert.isTrue(clientConfiguration.getScopes() != null && !clientConfiguration.getScopes().isEmpty(), + "At least one scope is required"); + Assert.notNull(clientConfiguration.getPrivateKey(), "privateKey cannot be null"); + Path privateKeyPemFilePath = Paths.get(clientConfiguration.getPrivateKey()); + boolean privateKeyPemFileExists = Files.exists(privateKeyPemFilePath, new LinkOption[]{ LinkOption.NOFOLLOW_LINKS }); + Assert.isTrue(privateKeyPemFileExists, "privateKey file does not exist"); + } + @Override public ClientBuilder setOrgUrl(String baseUrl) { ConfigurationValidator.assertOrgUrl(baseUrl, allowNonHttpsForTesting); @@ -329,8 +382,45 @@ public ClientBuilder setOrgUrl(String baseUrl) { return this; } + @Override + public ClientBuilder setAuthorizationMode(AuthorizationMode authorizationMode) { + this.clientConfig.setAuthorizationMode(authorizationMode); + this.clientConfig.setAuthenticationScheme(authorizationMode.getAuthenticationScheme()); + return this; + } + + @Override + public ClientBuilder setScopes(Set scopes) { + if (isOAuth2Flow()) { + Assert.isTrue(scopes != null && !scopes.isEmpty(), "At least one scope is required"); + this.clientConfig.setScopes(scopes); + } + return this; + } + + @Override + public ClientBuilder setPrivateKey(String privateKey) { + if (isOAuth2Flow()) { + Assert.notNull(privateKey, "Missing privateKey"); + this.clientConfig.setPrivateKey(privateKey); + } + return this; + } + + @Override + public ClientBuilder setClientId(String clientId) { + ConfigurationValidator.assertClientId(clientId); + this.clientConfig.setClientId(clientId); + return this; + } + + boolean isOAuth2Flow() { + return this.getClientConfiguration().getAuthorizationMode() == AuthorizationMode.PRIVATE_KEY; + } + // Used for testing, package private ClientConfiguration getClientConfiguration() { return clientConfig; } + } diff --git a/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java b/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java index 1b2a4d5aa51..c6ccc42f8e9 100644 --- a/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java +++ b/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java @@ -22,12 +22,15 @@ import com.okta.commons.lang.Strings; import com.okta.sdk.cache.CacheConfigurationBuilder; import com.okta.sdk.client.AuthenticationScheme; +import com.okta.sdk.client.AuthorizationMode; import com.okta.sdk.impl.api.ClientCredentialsResolver; import com.okta.sdk.impl.http.authc.DefaultRequestAuthenticatorFactory; import com.okta.sdk.impl.http.authc.RequestAuthenticatorFactory; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; /** * This class holds the default configuration properties. @@ -49,6 +52,10 @@ public class ClientConfiguration extends HttpClientConfiguration { private RequestAuthenticatorFactory requestAuthenticatorFactory = new DefaultRequestAuthenticatorFactory(); private AuthenticationScheme authenticationScheme; private BaseUrlResolver baseUrlResolver; + private AuthorizationMode authorizationMode; + private String clientId; + private Set scopes = new HashSet<>(); + private String privateKey; public String getApiToken() { return apiToken; @@ -90,6 +97,38 @@ public void setBaseUrlResolver(BaseUrlResolver baseUrlResolver) { this.baseUrlResolver = baseUrlResolver; } + public AuthorizationMode getAuthorizationMode() { + return authorizationMode; + } + + public void setAuthorizationMode(AuthorizationMode authorizationMode) { + this.authorizationMode = authorizationMode; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public Set getScopes() { + return scopes; + } + + public void setScopes(Set scopes) { + this.scopes = scopes; + } + + public String getPrivateKey() { + return privateKey; + } + + public void setPrivateKey(String privateKey) { + this.privateKey = privateKey; + } + public boolean isCacheManagerEnabled() { return cacheManagerEnabled; } @@ -162,17 +201,20 @@ public String getBaseUrl() { @Override public String toString() { - return "ClientConfiguration{" + - ", cacheManagerTtl=" + cacheManagerTtl + - ", cacheManagerTti=" + cacheManagerTti + - ", cacheManagerCaches=" + cacheManagerCaches + - ", baseUrl='" + getBaseUrl() + '\'' + - ", connectionTimeout=" + getConnectionTimeout() + - ", requestAuthenticator=" + getRequestAuthenticator() + - ", retryMaxElapsed=" + getRetryMaxElapsed() + - ", retryMaxAttempts=" + getRetryMaxAttempts() + - ", proxy=" + getProxy() + - '}'; + return "ClientConfiguration {cacheManagerTtl=" + cacheManagerTtl + + ", cacheManagerTti=" + cacheManagerTti + + ", cacheManagerCaches=" + cacheManagerCaches + + ", baseUrl='" + getBaseUrl() + '\'' + + ", authorizationMode=" + getAuthorizationMode() + + ", clientId=" + getClientId() + + ", scopes=" + getScopes() + + ", privateKey=" + ((getPrivateKey() != null) ? "xxxxx" : null) + + ", connectionTimeout=" + getConnectionTimeout() + + ", requestAuthenticator=" + getRequestAuthenticator() + + ", retryMaxElapsed=" + getRetryMaxElapsed() + + ", retryMaxAttempts=" + getRetryMaxAttempts() + + ", proxy=" + getProxy() + + " }"; } } diff --git a/impl/src/main/java/com/okta/sdk/impl/config/DefaultEnvVarNameConverter.java b/impl/src/main/java/com/okta/sdk/impl/config/DefaultEnvVarNameConverter.java index 699ebf57828..a999c424d3e 100644 --- a/impl/src/main/java/com/okta/sdk/impl/config/DefaultEnvVarNameConverter.java +++ b/impl/src/main/java/com/okta/sdk/impl/config/DefaultEnvVarNameConverter.java @@ -39,7 +39,11 @@ public DefaultEnvVarNameConverter() { ClientBuilder.DEFAULT_CLIENT_AUTHENTICATION_SCHEME_PROPERTY_NAME, ClientBuilder.DEFAULT_CLIENT_REQUEST_TIMEOUT_PROPERTY_NAME, ClientBuilder.DEFAULT_CLIENT_RETRY_MAX_ATTEMPTS_PROPERTY_NAME, - ClientBuilder.DEFAULT_CLIENT_TESTING_DISABLE_HTTPS_CHECK_PROPERTY_NAME); + ClientBuilder.DEFAULT_CLIENT_TESTING_DISABLE_HTTPS_CHECK_PROPERTY_NAME, + ClientBuilder.DEFAULT_CLIENT_AUTHORIZATION_MODE_PROPERTY_NAME, + ClientBuilder.DEFAULT_CLIENT_ID_PROPERTY_NAME, + ClientBuilder.DEFAULT_CLIENT_SCOPES_PROPERTY_NAME, + ClientBuilder.DEFAULT_CLIENT_PRIVATE_KEY_PROPERTY_NAME); } private Map buildReverseLookupToMap(String... dottedPropertyNames) { diff --git a/impl/src/main/java/com/okta/sdk/impl/http/authc/OAuth2RequestAuthenticator.java b/impl/src/main/java/com/okta/sdk/impl/http/authc/OAuth2RequestAuthenticator.java new file mode 100644 index 00000000000..d802b4234a1 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/http/authc/OAuth2RequestAuthenticator.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.http.authc; + +import com.okta.commons.http.Request; +import com.okta.commons.http.authc.RequestAuthenticationException; +import com.okta.commons.http.authc.RequestAuthenticator; +import com.okta.commons.lang.Assert; +import com.okta.sdk.authc.credentials.ClientCredentials; +import com.okta.sdk.impl.oauth2.OAuth2AccessToken; +import com.okta.sdk.impl.oauth2.OAuth2ClientCredentials; +import com.okta.sdk.impl.oauth2.OAuth2TokenRetrieverException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.security.InvalidKeyException; + +/** + * This implementation used by OAuth2 flow adds Bearer header with access token as the + * value in all outgoing requests. This has logic to fetch a new token and store it in + * {@link OAuth2ClientCredentials} object if the existing one has expired. + * + * @since 1.6.0 + */ +public class OAuth2RequestAuthenticator implements RequestAuthenticator { + private static final Logger log = LoggerFactory.getLogger(OAuth2RequestAuthenticator.class); + + private final ClientCredentials clientCredentials; + + public OAuth2RequestAuthenticator(ClientCredentials clientCredentials) { + Assert.notNull(clientCredentials, "clientCredentials may not be null"); + this.clientCredentials = clientCredentials; + } + + @Override + public void authenticate(Request request) throws RequestAuthenticationException { + OAuth2AccessToken oAuth2AccessToken = clientCredentials.getCredentials(); + + if (oAuth2AccessToken.hasExpired()) { + log.debug("OAuth2 access token expiry detected. Will fetch a new token from Authorization server"); + + synchronized (this) { + if (oAuth2AccessToken.hasExpired()) { + try { + OAuth2ClientCredentials oAuth2ClientCredentials = (OAuth2ClientCredentials) clientCredentials; + // fetch new token + oAuth2AccessToken = oAuth2ClientCredentials.getAccessTokenRetrieverService().getOAuth2AccessToken(); + // store the new token + oAuth2ClientCredentials.setCredentials(oAuth2AccessToken); + } catch (IOException | InvalidKeyException e) { + throw new OAuth2TokenRetrieverException("Failed to renew expired OAuth2 access token", e); + } + } + } + } + + // add Bearer header with token value + request.getHeaders().set(AUTHORIZATION_HEADER, "Bearer " + oAuth2AccessToken.getAccessToken()); + } + +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java new file mode 100644 index 00000000000..995a7a3a118 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverService.java @@ -0,0 +1,36 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import java.io.IOException; +import java.security.InvalidKeyException; + +/** + * Abstraction for OAuth2 access token retrieval service function. + * + * @since 1.6.0 + */ +public interface AccessTokenRetrieverService { + /** + * Obtain OAuth2 access token from Authorization Server endpoint. + * + * @return {@link OAuth2AccessToken} + * @throws IOException if problems are encountered extracting the input private key + * @throws InvalidKeyException if supplied private key is invalid + * @throws OAuth2TokenRetrieverException if token could not be retrieved + */ + OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyException, OAuth2TokenRetrieverException; +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java new file mode 100644 index 00000000000..56c2c971df0 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java @@ -0,0 +1,248 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.commons.http.MediaType; +import com.okta.commons.http.authc.DisabledAuthenticator; +import com.okta.commons.lang.Assert; +import com.okta.sdk.client.AuthenticationScheme; +import com.okta.sdk.client.AuthorizationMode; +import com.okta.sdk.impl.api.DefaultClientCredentialsResolver; +import com.okta.sdk.impl.config.ClientConfiguration; +import com.okta.sdk.impl.error.DefaultError; +import com.okta.sdk.resource.ExtensibleResource; +import com.okta.sdk.resource.ResourceException; +import io.jsonwebtoken.Jwts; +import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; +import org.bouncycastle.openssl.PEMKeyPair; +import org.bouncycastle.openssl.PEMParser; +import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Reader; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.InvalidKeyException; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Date; +import java.util.Optional; +import java.util.UUID; + +/** + * Implementation of {@link AccessTokenRetrieverService} interface. + * + * This has logic to fetch OAuth2 access token from the Authorization server endpoint. + * + * @since 1.6.0 + */ +public class AccessTokenRetrieverServiceImpl implements AccessTokenRetrieverService { + private static final Logger log = LoggerFactory.getLogger(AccessTokenRetrieverServiceImpl.class); + + private static final String TOKEN_URI = "/oauth2/v1/token"; + + private final ClientConfiguration tokenClientConfiguration; + private final OAuth2TokenClient tokenClient; + + public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration) { + Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null."); + ClientConfiguration tokenClientConfig = constructTokenClientConfig(apiClientConfiguration); + this.tokenClient = new OAuth2TokenClient(tokenClientConfig); + this.tokenClientConfiguration = tokenClientConfig; + } + + public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration, OAuth2TokenClient tokenClient) { + Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null."); + Assert.notNull(tokenClient, "tokenClient must not be null."); + this.tokenClient = tokenClient; + this.tokenClientConfiguration = constructTokenClientConfig(apiClientConfiguration); + } + + /** + * {@inheritDoc} + */ + @Override + public OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyException, OAuth2TokenRetrieverException { + log.debug("Attempting to get OAuth2 access token for client id {} from {}", + tokenClientConfiguration.getClientId(), tokenClientConfiguration.getBaseUrl() + TOKEN_URI); + + String signedJwt = createSignedJWT(); + String scope = String.join(" ", tokenClientConfiguration.getScopes()); + + try { + ExtensibleResource accessTokenResponse = tokenClient.http() + .addHeaderParameter("Accept", MediaType.APPLICATION_JSON_VALUE) + .addHeaderParameter("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .addQueryParameter("grant_type", "client_credentials") + .addQueryParameter("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + .addQueryParameter("client_assertion", signedJwt) + .addQueryParameter("scope", scope) + .post(TOKEN_URI, ExtensibleResource.class); + + OAuth2AccessToken oAuth2AccessToken = new OAuth2AccessToken(); + oAuth2AccessToken.setTokenType(accessTokenResponse.getString(OAuth2AccessToken.TOKEN_TYPE_KEY)); + oAuth2AccessToken.setExpiresIn(accessTokenResponse.getInteger(OAuth2AccessToken.EXPIRES_IN_KEY)); + oAuth2AccessToken.setAccessToken(accessTokenResponse.getString(OAuth2AccessToken.ACCESS_TOKEN_KEY)); + oAuth2AccessToken.setScope(accessTokenResponse.getString(OAuth2AccessToken.SCOPE_KEY)); + + log.debug("Got OAuth2 access token for client id {} from {}", + tokenClientConfiguration.getClientId(), tokenClientConfiguration.getBaseUrl() + TOKEN_URI); + + return oAuth2AccessToken; + } catch (ResourceException e) { + //TODO: clean up the ugly casting and refactor code around it. + DefaultError defaultError = (DefaultError) e.getError(); + String errorMessage = defaultError.getString(OAuth2AccessToken.ERROR_KEY); + String errorDescription = defaultError.getString(OAuth2AccessToken.ERROR_DESCRIPTION); + defaultError.setMessage(errorMessage + " - " + errorDescription); + throw new OAuth2HttpException(defaultError, e, e.getStatus() == 401); + } catch (Exception e) { + throw new OAuth2TokenRetrieverException("Exception while trying to get " + + "OAuth2 access token for client id " + tokenClientConfiguration.getClientId(), e); + } + } + + /** + * Create signed JWT string with the supplied token client configuration details. + * + * @return signed JWT string + * @throws InvalidKeyException + * @throws IOException + */ + String createSignedJWT() throws InvalidKeyException, IOException { + String clientId = tokenClientConfiguration.getClientId(); + PrivateKey privateKey = parsePrivateKey(tokenClientConfiguration.getPrivateKey()); + Instant now = Instant.now(); + + String jwt = Jwts.builder() + .setAudience(tokenClientConfiguration.getBaseUrl() + TOKEN_URI) + .setIssuedAt(Date.from(now)) + .setExpiration(Date.from(now.plus(1L, ChronoUnit.HOURS))) + .setIssuer(clientId) + .setSubject(clientId) + .claim("jti", UUID.randomUUID().toString()) + .signWith(privateKey) + .compact(); + + return jwt; + } + + /** + * Parse private key from the supplied path. + * + * @param privateKeyFilePath + * @return {@link PrivateKey} + * @throws IOException + * @throws InvalidKeyException + */ + PrivateKey parsePrivateKey(String privateKeyFilePath) throws IOException, InvalidKeyException { + Assert.notNull(privateKeyFilePath, "privateKeyFilePath may not be null"); + + Path privateKeyPemFilePath = Paths.get(privateKeyFilePath); + Reader reader = Files.newBufferedReader(privateKeyPemFilePath, Charset.defaultCharset()); + + PrivateKey privateKey = getPrivateKeyFromPEM(reader); + String algorithm = privateKey.getAlgorithm(); + + if (!algorithm.equals("RSA") && + !algorithm.equals("EC")) { + throw new InvalidKeyException("Supplied privateKey is not an RSA or EC key - " + algorithm); + } + + return privateKey; + } + + /** + * Get Private key from input PEM file. + * + * @param reader + * @return {@link PrivateKey} + * @throws IOException + */ + PrivateKey getPrivateKeyFromPEM(Reader reader) throws IOException { + PrivateKey privateKey; + + try (PEMParser pemParser = new PEMParser(reader)) { + JcaPEMKeyConverter jcaPEMKeyConverter = new JcaPEMKeyConverter(); + Object pemContent = pemParser.readObject(); + + if (pemContent == null) { + throw new IllegalArgumentException("Invalid Private Key PEM file"); + } + + if (pemContent instanceof PEMKeyPair) { + PEMKeyPair pemKeyPair = (PEMKeyPair) pemContent; + KeyPair keyPair = jcaPEMKeyConverter.getKeyPair(pemKeyPair); + privateKey = keyPair.getPrivate(); + } else if (pemContent instanceof PrivateKeyInfo) { + PrivateKeyInfo privateKeyInfo = (PrivateKeyInfo) pemContent; + privateKey = jcaPEMKeyConverter.getPrivateKey(privateKeyInfo); + } else { + throw new IllegalArgumentException("Unsupported Private Key format '" + + pemContent.getClass().getSimpleName() + '"'); + } + } + + return privateKey; + } + + /** + * Create token client config from the supplied API client config. + * + * Token client needs to retry http 401 errors only once which is not the case with the API client. + * We therefore effect this token client specific config by setting 'retryMaxElapsed' + * & 'retryMaxAttempts' fields. + * + * @param apiClientConfiguration supplied in the client configuration. + * @return ClientConfiguration to be used by token retrieval client. + */ + ClientConfiguration constructTokenClientConfig(ClientConfiguration apiClientConfiguration) { + ClientConfiguration tokenClientConfiguration = new ClientConfiguration(); + + tokenClientConfiguration.setClientCredentialsResolver( + new DefaultClientCredentialsResolver(() -> Optional.empty())); + + tokenClientConfiguration.setRequestAuthenticator(new DisabledAuthenticator()); + + tokenClientConfiguration.setCacheManagerEnabled(false); + + if (apiClientConfiguration.getBaseUrlResolver() != null) + tokenClientConfiguration.setBaseUrlResolver(apiClientConfiguration.getBaseUrlResolver()); + + if (apiClientConfiguration.getProxy() != null) + tokenClientConfiguration.setProxy(apiClientConfiguration.getProxy()); + + tokenClientConfiguration.setAuthenticationScheme(AuthenticationScheme.NONE); + tokenClientConfiguration.setAuthorizationMode(AuthorizationMode.get(tokenClientConfiguration.getAuthenticationScheme())); + tokenClientConfiguration.setClientId(apiClientConfiguration.getClientId()); + tokenClientConfiguration.setScopes(apiClientConfiguration.getScopes()); + tokenClientConfiguration.setPrivateKey(apiClientConfiguration.getPrivateKey()); + + // setting this to '0' will disable this check and only 'retryMaxAttempts' will be effective + tokenClientConfiguration.setRetryMaxElapsed(0); + // retry only once for token requests (http 401 errors) + tokenClientConfiguration.setRetryMaxAttempts(1); + + return tokenClientConfiguration; + } + +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java new file mode 100644 index 00000000000..8428022df9b --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2AccessToken.java @@ -0,0 +1,102 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import java.time.Duration; +import java.time.Instant; + +/** + * Represents the OAuth2 access token returned by Authorization server. + * + * @since 1.6.0 + */ +public class OAuth2AccessToken { + + /* Token body constants */ + public static final String TOKEN_TYPE_KEY = "token_type"; + public static final String EXPIRES_IN_KEY = "expires_in"; + public static final String ACCESS_TOKEN_KEY = "access_token"; + public static final String SCOPE_KEY = "scope"; + + /* Token error constants */ + public static final String ERROR_KEY = "error"; + public static final String ERROR_DESCRIPTION = "error_description"; + + private String tokenType; + + private Integer expiresIn; + + private String accessToken; + + private String scope; + + private Instant issuedAt = Instant.now(); + + public String getTokenType() { + return tokenType; + } + + public void setTokenType(String tokenType) { + this.tokenType = tokenType; + } + + public Integer getExpiresIn() { + return expiresIn; + } + + public void setExpiresIn(Integer expiresIn) { + this.expiresIn = expiresIn; + } + + public String getAccessToken() { + return accessToken; + } + + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + public String getScope() { + return scope; + } + + public void setScope(String scope) { + this.scope = scope; + } + + public Instant getIssuedAt() { + return issuedAt; + } + + public boolean hasExpired() { + Duration duration = Duration.between(this.getIssuedAt(), Instant.now()); + return duration.getSeconds() >= this.getExpiresIn(); + } + + // for testing purposes + void expireNow() { + this.setExpiresIn(Integer.MIN_VALUE); + } + + @Override + public String toString() { + return "OAuth2AccessToken [tokenType=" + tokenType + + ", issuedAt=" + issuedAt + + ", expiresIn=" + expiresIn + + ", accessToken=xxxxx" + + ", scope=" + scope + "]"; + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java new file mode 100644 index 00000000000..a514c2ce19e --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2ClientCredentials.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.commons.lang.Assert; +import com.okta.sdk.authc.credentials.ClientCredentials; + +import java.io.IOException; +import java.security.InvalidKeyException; + +/** + * This implementation represents client credentials specific to OAuth2 Authentication scheme. + * + * @since 1.6.0 + */ +public class OAuth2ClientCredentials implements ClientCredentials { + + private OAuth2AccessToken oAuth2AccessToken; + private AccessTokenRetrieverService accessTokenRetrieverService; + + public OAuth2ClientCredentials(AccessTokenRetrieverService accessTokenRetrieverService) { + Assert.notNull(accessTokenRetrieverService, "accessTokenRetrieverService must not be null."); + this.accessTokenRetrieverService = accessTokenRetrieverService; + this.oAuth2AccessToken = eagerFetchOAuth2AccessToken(); + } + + private OAuth2AccessToken eagerFetchOAuth2AccessToken() { + OAuth2AccessToken accessToken; + + try { + accessToken = accessTokenRetrieverService.getOAuth2AccessToken(); + } catch (IOException | InvalidKeyException e) { + throw new OAuth2TokenRetrieverException("Failed to fetch OAuth2 access token eagerly", e); + } + + if (accessToken == null) { + throw new OAuth2TokenRetrieverException("Failed to fetch OAuth2 access token eagerly"); + } + + return accessToken; + } + + public OAuth2AccessToken getCredentials() { + return oAuth2AccessToken; + } + + public void setCredentials(OAuth2AccessToken oAuth2AccessToken) { + this.oAuth2AccessToken = oAuth2AccessToken; + } + + public AccessTokenRetrieverService getAccessTokenRetrieverService() { + return accessTokenRetrieverService; + } + + @Override + public String toString() { + // never ever print the secret + return ""; + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java new file mode 100644 index 00000000000..abfcc1166d6 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2HttpException.java @@ -0,0 +1,29 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.commons.http.HttpException; +import com.okta.sdk.impl.error.DefaultError; + +/** + * @since 1.6.0 + */ +public class OAuth2HttpException extends HttpException { + + public OAuth2HttpException(DefaultError e, Throwable cause, boolean retryable) { + super(e.getMessage(), cause, retryable); + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenClient.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenClient.java new file mode 100644 index 00000000000..eb78e44894c --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenClient.java @@ -0,0 +1,30 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +import com.okta.sdk.cache.Caches; +import com.okta.sdk.impl.client.BaseClient; +import com.okta.sdk.impl.config.ClientConfiguration; + +/** + * @since 1.6.0 + */ +public class OAuth2TokenClient extends BaseClient { + + public OAuth2TokenClient(ClientConfiguration tokenClientConfig) { + super(tokenClientConfig, Caches.newDisabledCacheManager()); + } +} diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java new file mode 100644 index 00000000000..d602b582888 --- /dev/null +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/OAuth2TokenRetrieverException.java @@ -0,0 +1,30 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2; + +/** + * @since 1.6.0 + */ +public class OAuth2TokenRetrieverException extends RuntimeException { + + public OAuth2TokenRetrieverException(String s) { + super(s); + } + + public OAuth2TokenRetrieverException(String s, Throwable cause) { + super(s, cause); + } +} diff --git a/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTest.groovy b/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTest.groovy index a00cde8d51e..64b78a72e09 100644 --- a/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTest.groovy +++ b/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTest.groovy @@ -16,26 +16,35 @@ */ package com.okta.sdk.impl.client +import com.okta.commons.http.config.BaseUrlResolver import com.okta.sdk.authc.credentials.TokenClientCredentials import com.okta.sdk.client.AuthenticationScheme +import com.okta.sdk.client.AuthorizationMode import com.okta.sdk.client.ClientBuilder import com.okta.sdk.client.Clients +import com.okta.sdk.impl.Util import com.okta.sdk.impl.io.DefaultResourceFactory import com.okta.sdk.impl.io.Resource import com.okta.sdk.impl.io.ResourceFactory +import com.okta.sdk.impl.oauth2.OAuth2TokenRetrieverException import com.okta.sdk.impl.test.RestoreEnvironmentVariables import com.okta.sdk.impl.test.RestoreSystemProperties -import com.okta.commons.http.config.BaseUrlResolver -import com.okta.sdk.impl.Util import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.testng.annotations.Listeners import org.testng.annotations.Test -import static org.testng.Assert.* -import static org.mockito.Mockito.* -import static org.hamcrest.Matchers.* +import java.security.KeyPair +import java.security.KeyPairGenerator +import java.security.PrivateKey + import static org.hamcrest.MatcherAssert.assertThat +import static org.hamcrest.Matchers.is +import static org.hamcrest.Matchers.nullValue +import static org.mockito.ArgumentMatchers.anyString +import static org.mockito.Mockito.* +import static org.testng.Assert.assertEquals +import static org.testng.Assert.assertTrue @Listeners([RestoreSystemProperties, RestoreEnvironmentVariables]) class DefaultClientBuilderTest { @@ -48,9 +57,17 @@ class DefaultClientBuilderTest { void clearOktaEnvAndSysProps() { System.clearProperty("okta.client.token") System.clearProperty("okta.client.orgUrl") + System.clearProperty("okta.client.authorizationMode") + System.clearProperty("okta.client.clientId") + System.clearProperty("okta.client.scopes") + System.clearProperty("okta.client.privateKey") RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_TOKEN", null) RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_ORGURL", null) + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_AUTHORIZATIONMODE", null) + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_CLIENTID", null) + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_SCOPES", null) + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_PRIVATEKEY", null) } @Test @@ -112,6 +129,7 @@ class DefaultClientBuilderTest { @Test void testConfigureBaseUrlResolver(){ + clearOktaEnvAndSysProps() BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { @Override String getBaseUrl() { @@ -174,6 +192,174 @@ class DefaultClientBuilderTest { } } + @Test + void testOAuth2NullClientId() { + clearOktaEnvAndSysProps() + Util.expect(IllegalArgumentException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .build() + } + } + + @Test + void testOAuth2NullScopes() { + clearOktaEnvAndSysProps() + Util.expect(IllegalArgumentException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .build() + } + } + + @Test + void testOAuth2EmptyScopes() { + clearOktaEnvAndSysProps() + Util.expect(IllegalArgumentException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .setScopes(new HashSet()) + .build() + } + } + + @Test + void testOAuth2NullPrivateKey() { + clearOktaEnvAndSysProps() + Util.expect(IllegalArgumentException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .setScopes(new HashSet<>(Arrays.asList({"okta.apps.read"}))) + .setPrivateKey(null) + .build() + } + } + + @Test + void testOAuth2InvalidPrivateKeyPemFilePath() { + clearOktaEnvAndSysProps() + Util.expect(IllegalArgumentException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .setScopes(new HashSet<>(Arrays.asList({"okta.apps.read"}))) + .setPrivateKey("blahblah.pem") + .build() + } + } + + @Test + void testOAuth2InvalidPrivateKeyPemFileContent() { + clearOktaEnvAndSysProps() + File privateKeyFile = File.createTempFile("tmp",".pem") + privateKeyFile.write("-----INVALID PEM CONTENT-----") + Util.expect(IllegalArgumentException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .setScopes(new HashSet<>(Arrays.asList({"okta.apps.read"}))) + .setPrivateKey(privateKeyFile.path) + .build() + } + + privateKeyFile.delete() + } + + @Test + void testOAuth2UnsupportedPrivateKeyAlgorithm() { + clearOktaEnvAndSysProps() + + // DSA algorithm is unsupported (we support only RSA & EC) + File privateKeyFile = generatePrivateKey("DSA", 2048, "privateKey", ".pem") + + Set scopes = new HashSet<>(); + scopes.add("okta.apps.read") + scopes.add("okta.apps.manage") + + Util.expect(OAuth2TokenRetrieverException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .setScopes(scopes) + .setPrivateKey(privateKeyFile.path) + .build() + } + + privateKeyFile.delete() + } + + @Test + void testOAuth2SemanticallyValidInputParams() { + clearOktaEnvAndSysProps() + + File privateKeyFile = generatePrivateKey("RSA", 2048, "privateKey", ".pem") + + Set scopes = new HashSet<>(); + scopes.add("okta.apps.read") + scopes.add("okta.apps.manage") + + // expected because the URL is not an actual endpoint + Util.expect(OAuth2TokenRetrieverException) { + new DefaultClientBuilder(noDefaultYamlNoAppYamlResourceFactory()) + .setOrgUrl("https://okta.example.com") + .setAuthorizationMode(AuthorizationMode.PRIVATE_KEY) + .setClientId("client12345") + .setScopes(scopes) + .setPrivateKey(privateKeyFile.path) + .build() + } + + privateKeyFile.delete() + } + + @Test + void testOAuth2WithEnvVariables() { + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_ORGURL", + "https://okta.example.com") + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_AUTHORIZATIONMODE", + AuthorizationMode.PRIVATE_KEY.getLabel()) // "PrivateKey" + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_CLIENTID", + "client12345") + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_SCOPES", + "okta.users.read okta.users.manage okta.apps.read okta.apps.manage") + + File privateKeyFile = generatePrivateKey("RSA", 2048, "privateKey", ".pem") + + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_PRIVATEKEY", privateKeyFile.path) + + // expected because the URL is not an actual endpoint + Util.expect(OAuth2TokenRetrieverException) { + new DefaultClientBuilder().build() + } + + privateKeyFile.delete() + } + + // helper methods + + static generatePrivateKey(String algorithm, int keySize, String fileNamePrefix, String fileNameSuffix) { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm) + keyGen.initialize(keySize) + KeyPair key = keyGen.generateKeyPair() + PrivateKey privateKey = key.getPrivate() + String encodedString = "-----BEGIN PRIVATE KEY-----\n" + encodedString = encodedString + Base64.getEncoder().encodeToString(privateKey.getEncoded()) + "\n" + encodedString = encodedString + "-----END PRIVATE KEY-----\n" + File file = File.createTempFile(fileNamePrefix,fileNameSuffix) + file.write(encodedString) + return file + } + static ResourceFactory noDefaultYamlNoAppYamlResourceFactory() { def resourceFactory = spy(new DefaultResourceFactory()) doAnswer(new Answer() { diff --git a/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTestCustomCredentialsTest.groovy b/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTestCustomCredentialsTest.groovy index 4dcfbbf34a7..c32344f624d 100644 --- a/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTestCustomCredentialsTest.groovy +++ b/impl/src/test/groovy/com/okta/sdk/impl/client/DefaultClientBuilderTestCustomCredentialsTest.groovy @@ -18,6 +18,7 @@ package com.okta.sdk.impl.client import com.okta.sdk.authc.credentials.ClientCredentials import com.okta.sdk.impl.api.DefaultClientCredentialsResolver import com.okta.sdk.authc.credentials.TokenClientCredentials +import com.okta.sdk.impl.test.RestoreEnvironmentVariables import org.testng.annotations.BeforeMethod import org.testng.annotations.Test @@ -29,6 +30,7 @@ class DefaultClientBuilderTestCustomCredentialsTest { @BeforeMethod void before() { + clearOktaEnvAndSysProps() id = UUID.randomUUID().toString() secret = UUID.randomUUID().toString() @@ -47,6 +49,8 @@ class DefaultClientBuilderTestCustomCredentialsTest { @Test void testCustomClientCredentialsAllowedWithApiKeyResolver(){ + clearOktaEnvAndSysProps() + def credentialsSecret = UUID.randomUUID().toString() ClientCredentials customCredentials = new ClientCredentials() { @@ -64,8 +68,14 @@ class DefaultClientBuilderTestCustomCredentialsTest { builder = new DefaultClientBuilder() builder.setClientCredentials(customCredentials) builder.setClientCredentialsResolver(apiKeyResolver) + def testClient = builder.build() assertEquals testClient.dataStore.clientCredentials.credentials, keySecret } + + void clearOktaEnvAndSysProps() { + System.clearProperty("okta.client.authorizationMode") + RestoreEnvironmentVariables.setEnvironmentVariable("OKTA_CLIENT_AUTHORIZATIONMODE", null) + } } diff --git a/impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorConcurrencyTest.groovy b/impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorConcurrencyTest.groovy new file mode 100644 index 00000000000..50e1aa2f5ee --- /dev/null +++ b/impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorConcurrencyTest.groovy @@ -0,0 +1,87 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.http + +import com.okta.commons.http.HttpHeaders +import com.okta.commons.http.Request +import com.okta.commons.http.authc.RequestAuthenticator + +import com.okta.sdk.impl.http.authc.OAuth2RequestAuthenticator +import com.okta.sdk.impl.oauth2.AccessTokenRetrieverService +import com.okta.sdk.impl.oauth2.OAuth2AccessToken +import com.okta.sdk.impl.oauth2.OAuth2ClientCredentials + +import org.testng.annotations.AfterTest +import org.testng.annotations.BeforeTest +import org.testng.annotations.Test + +import static org.mockito.Mockito.when +import static org.mockito.Mockito.verify +import static org.mockito.Mockito.mock +import static org.mockito.Mockito.times + +/** + * Concurreny test for {@link com.okta.sdk.impl.http.authc.OAuth2RequestAuthenticator} class + * + * @since 1.6.0 + */ +class OAuth2RequestAuthenticatorConcurrencyTest { + + OAuth2RequestAuthenticator oAuth2RequestAuthenticator + + def initialAccessTokenStr = "initial-token-12345" + def refreshedAccessTokenStr = "refreshed-token-12345" + + def request = mock(Request) + def clientCredentials = mock(OAuth2ClientCredentials) + def accessTokenRetrievalService = mock(AccessTokenRetrieverService) + def initialAccessTokenObj = mock(OAuth2AccessToken) + def refreshedAccessTokenObj = mock(OAuth2AccessToken) + def httpHeaders = mock(HttpHeaders) + + @BeforeTest + void initialize() { + oAuth2RequestAuthenticator = new OAuth2RequestAuthenticator(clientCredentials) + + when(initialAccessTokenObj.getAccessToken()).thenReturn(initialAccessTokenStr) + when(initialAccessTokenObj.hasExpired()).thenReturn(true) + + when(refreshedAccessTokenObj.getAccessToken()).thenReturn(refreshedAccessTokenStr) + + when(clientCredentials.getCredentials()).thenReturn(initialAccessTokenObj) + when(clientCredentials.getAccessTokenRetrieverService()).thenReturn(accessTokenRetrievalService) + when(accessTokenRetrievalService.getOAuth2AccessToken()).thenReturn(refreshedAccessTokenObj) + + when(request.getHeaders()).thenReturn(httpHeaders) + } + + @Test(threadPoolSize = 5, invocationCount = 10) + void testAuthenticateRequestWithExpiredInitialToken() { + oAuth2RequestAuthenticator.authenticate(request) + Thread.sleep((long)(Math.random() * 1000)) /* sleep random time (max 1000 ms) */ + } + + @AfterTest + void verifyMocks() { + verify(clientCredentials, times(10)).getCredentials() + verify(initialAccessTokenObj, times(20)).hasExpired() // double locking + verify(accessTokenRetrievalService, times(10)).getOAuth2AccessToken() + verify(clientCredentials, times(10)).setCredentials(refreshedAccessTokenObj) + verify(request.getHeaders(), times(10)) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + refreshedAccessTokenStr) + } + +} \ No newline at end of file diff --git a/impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorTest.groovy b/impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorTest.groovy new file mode 100644 index 00000000000..d835831e280 --- /dev/null +++ b/impl/src/test/groovy/com/okta/sdk/impl/http/OAuth2RequestAuthenticatorTest.groovy @@ -0,0 +1,223 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.http + +import com.okta.commons.http.HttpHeaders +import com.okta.commons.http.Request +import com.okta.commons.http.authc.RequestAuthenticator +import com.okta.sdk.impl.Util +import com.okta.sdk.impl.http.authc.OAuth2RequestAuthenticator +import com.okta.sdk.impl.oauth2.AccessTokenRetrieverService +import com.okta.sdk.impl.oauth2.OAuth2AccessToken +import com.okta.sdk.impl.oauth2.OAuth2ClientCredentials +import com.okta.sdk.impl.oauth2.OAuth2TokenRetrieverException + +import org.testng.annotations.Test + +import java.security.InvalidKeyException + +import static org.mockito.Mockito.mock +import static org.mockito.Mockito.never +import static org.mockito.Mockito.reset +import static org.mockito.Mockito.times +import static org.mockito.Mockito.verify +import static org.mockito.Mockito.when + +/** + * Test for {@link com.okta.sdk.impl.http.authc.OAuth2RequestAuthenticator} class + * + * @since 1.6.0 + */ +class OAuth2RequestAuthenticatorTest { + + @Test + void testInstantiationWithNullClientCredentials() { + Util.expect(IllegalArgumentException) { + new OAuth2RequestAuthenticator(null) + } + } + + @Test + void testAuthenticateRequestWithUnexpiredInitialToken() { + def initialAccessTokenStr = "initial-token-12345" + + def request = mock(Request) + + def clientCredentials = mock(OAuth2ClientCredentials) + def accessTokenRetrievalService = mock(AccessTokenRetrieverService) + def initialAccessTokenObj = mock(OAuth2AccessToken) + + def httpHeaders = mock(HttpHeaders) + + when(initialAccessTokenObj.getAccessToken()).thenReturn(initialAccessTokenStr) + when(initialAccessTokenObj.hasExpired()).thenReturn(false) + + when(clientCredentials.getCredentials()).thenReturn(initialAccessTokenObj) + when(clientCredentials.getAccessTokenRetrieverService()).thenReturn(accessTokenRetrievalService) + + when(request.getHeaders()).thenReturn(httpHeaders) + + new OAuth2RequestAuthenticator(clientCredentials).authenticate(request) + + verify(clientCredentials, times(1)).getCredentials() + verify(initialAccessTokenObj, times(1)).hasExpired() + verify(request.getHeaders(), times(1)) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + initialAccessTokenStr) + } + + @Test + void testAuthenticateRequestWithExpiredInitialToken() { + def initialAccessTokenStr = "initial-token-12345" + def refreshedAccessTokenStr = "refreshed-token-12345" + + def request = mock(Request) + + def clientCredentials = mock(OAuth2ClientCredentials) + def accessTokenRetrievalService = mock(AccessTokenRetrieverService) + def initialAccessTokenObj = mock(OAuth2AccessToken) + def refreshedAccessTokenObj = mock(OAuth2AccessToken) + + def httpHeaders = mock(HttpHeaders) + + when(initialAccessTokenObj.getAccessToken()).thenReturn(initialAccessTokenStr) + when(initialAccessTokenObj.hasExpired()).thenReturn(true) + + when(refreshedAccessTokenObj.getAccessToken()).thenReturn(refreshedAccessTokenStr) + + when(clientCredentials.getCredentials()).thenReturn(initialAccessTokenObj) + when(clientCredentials.getAccessTokenRetrieverService()).thenReturn(accessTokenRetrievalService) + when(accessTokenRetrievalService.getOAuth2AccessToken()).thenReturn(refreshedAccessTokenObj) + + when(request.getHeaders()).thenReturn(httpHeaders) + + new OAuth2RequestAuthenticator(clientCredentials).authenticate(request) + + verify(clientCredentials, times(1)).getCredentials() + verify(initialAccessTokenObj, times(2)).hasExpired() // double locking + verify(accessTokenRetrievalService, times(1)).getOAuth2AccessToken() + verify(clientCredentials, times(1)).setCredentials(refreshedAccessTokenObj) + verify(request.getHeaders(), times(1)) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + refreshedAccessTokenStr) + } + + @Test(expectedExceptions = OAuth2TokenRetrieverException) + void testRefreshTokenFetchException() { + def initialAccessTokenStr = "initial-token-12345" + + def request = mock(Request) + + def clientCredentials = mock(OAuth2ClientCredentials) + def accessTokenRetrievalService = mock(AccessTokenRetrieverService) + def initialAccessTokenObj = mock(OAuth2AccessToken) + + when(initialAccessTokenObj.getAccessToken()).thenReturn(initialAccessTokenStr) + when(initialAccessTokenObj.hasExpired()).thenReturn(true) + + when(clientCredentials.getCredentials()).thenReturn(initialAccessTokenObj) + when(clientCredentials.getAccessTokenRetrieverService()).thenReturn(accessTokenRetrievalService) + when(accessTokenRetrievalService.getOAuth2AccessToken()) + .thenThrow(new OAuth2TokenRetrieverException("Failed to renew expired OAuth2 access token")) + + new OAuth2RequestAuthenticator(clientCredentials).authenticate(request) + + verify(clientCredentials, times(1)).getCredentials() + verify(initialAccessTokenObj, times(2)).hasExpired() // double locking + verify(clientCredentials, never()).setCredentials(mock(OAuth2AccessToken)) + verify(request.getHeaders(), never()) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + any(String.class)) + } + + @Test(expectedExceptions = OAuth2TokenRetrieverException) + void testRefreshTokenFetchInvalidKeyException() { + def initialAccessTokenStr = "initial-token-12345" + + def request = mock(Request) + + def clientCredentials = mock(OAuth2ClientCredentials) + def accessTokenRetrievalService = mock(AccessTokenRetrieverService) + def initialAccessTokenObj = mock(OAuth2AccessToken) + + when(initialAccessTokenObj.getAccessToken()).thenReturn(initialAccessTokenStr) + when(initialAccessTokenObj.hasExpired()).thenReturn(true) + + when(clientCredentials.getCredentials()).thenReturn(initialAccessTokenObj) + when(clientCredentials.getAccessTokenRetrieverService()).thenReturn(accessTokenRetrievalService) + when(accessTokenRetrievalService.getOAuth2AccessToken()) + .thenThrow(new InvalidKeyException("Failed to renew expired OAuth2 access token")) + + new OAuth2RequestAuthenticator(clientCredentials).authenticate(request) + + verify(clientCredentials, times(1)).getCredentials() + verify(initialAccessTokenObj, times(2)).hasExpired() // double locking + verify(clientCredentials, never()).setCredentials(mock(OAuth2AccessToken)) + verify(request.getHeaders(), never()) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + any(String.class)) + } + + @Test + void testRefreshTokenFetchWithTokenReuse() { + def initialAccessTokenStr = "initial-token-12345" + def refreshedAccessTokenStr = "refreshed-token-12345" + + def request = mock(Request) + + def clientCredentials = mock(OAuth2ClientCredentials) + def accessTokenRetrievalService = mock(AccessTokenRetrieverService) + def initialAccessTokenObj = mock(OAuth2AccessToken) + def refreshedAccessTokenObj = mock(OAuth2AccessToken) + + def httpHeaders = mock(HttpHeaders) + + when(initialAccessTokenObj.getAccessToken()).thenReturn(initialAccessTokenStr) + when(initialAccessTokenObj.hasExpired()).thenReturn(true) + + when(refreshedAccessTokenObj.getAccessToken()).thenReturn(refreshedAccessTokenStr) + + when(clientCredentials.getCredentials()).thenReturn(initialAccessTokenObj) + when(clientCredentials.getAccessTokenRetrieverService()).thenReturn(accessTokenRetrievalService) + when(accessTokenRetrievalService.getOAuth2AccessToken()).thenReturn(refreshedAccessTokenObj) + + when(request.getHeaders()).thenReturn(httpHeaders) + + new OAuth2RequestAuthenticator(clientCredentials).authenticate(request) + + verify(clientCredentials, times(1)).getCredentials() + verify(initialAccessTokenObj, times(2)).hasExpired() // double locking + verify(accessTokenRetrievalService, times(1)).getOAuth2AccessToken() + verify(clientCredentials, times(1)).setCredentials(refreshedAccessTokenObj) + verify(request.getHeaders(), times(1)) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + refreshedAccessTokenStr) + + // reset mocks + reset(request, clientCredentials, accessTokenRetrievalService, initialAccessTokenObj, refreshedAccessTokenObj, httpHeaders) + + // reuse the refreshed token which we got above (do not expire it) + when(request.getHeaders()).thenReturn(httpHeaders) + when(clientCredentials.getCredentials()).thenReturn(refreshedAccessTokenObj) + when(refreshedAccessTokenObj.getAccessToken()).thenReturn(refreshedAccessTokenStr) + when(refreshedAccessTokenObj.hasExpired()).thenReturn(false) + + new OAuth2RequestAuthenticator(clientCredentials).authenticate(request) + + verify(clientCredentials, times(1)).getCredentials() + verify(refreshedAccessTokenObj, times(1)).hasExpired() + verify(accessTokenRetrievalService, never()).getOAuth2AccessToken() + verify(clientCredentials, never()).setCredentials(refreshedAccessTokenObj) + verify(request.getHeaders(), times(1)) + .set(RequestAuthenticator.AUTHORIZATION_HEADER, "Bearer " + refreshedAccessTokenStr) + } + +} \ No newline at end of file diff --git a/impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy b/impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy new file mode 100644 index 00000000000..07b9bc2df11 --- /dev/null +++ b/impl/src/test/groovy/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImplTest.groovy @@ -0,0 +1,255 @@ +/* + * Copyright 2020-Present Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.okta.sdk.impl.oauth2 + +import com.okta.commons.http.config.BaseUrlResolver +import com.okta.sdk.ds.RequestBuilder +import com.okta.sdk.impl.Util +import com.okta.sdk.impl.api.DefaultClientCredentialsResolver +import com.okta.sdk.impl.config.ClientConfiguration +import com.okta.sdk.impl.error.DefaultError +import com.okta.sdk.resource.ResourceException +import io.jsonwebtoken.Claims +import io.jsonwebtoken.Jwts +import org.testng.annotations.Test + +import java.security.KeyPair +import java.security.KeyPairGenerator +import java.security.PrivateKey + +import static org.hamcrest.MatcherAssert.assertThat +import static org.hamcrest.Matchers.is +import static org.hamcrest.Matchers.notNullValue +import static org.mockito.ArgumentMatchers.any +import static org.mockito.ArgumentMatchers.anyString +import static org.mockito.Mockito.when +import static org.mockito.Mockito.verify +import static org.mockito.Mockito.mock +import static org.mockito.Mockito.times +import static org.testng.Assert.assertEquals + +/** + * Test for {@link AccessTokenRetrieverServiceImpl} class + * + * @since 1.6.0 + */ +class AccessTokenRetrieverServiceImplTest { + + @Test + void testInstantiationWithNullClientConfig() { + Util.expect(IllegalArgumentException) { + new AccessTokenRetrieverServiceImpl(null) + } + } + + @Test + void testGetPrivateKeyFromPem() { + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + // Now test the pem -> private key conversion function of getPrivateKeyFromPem method + Reader pemFileReader = new FileReader(privateKeyPemFile) + + PrivateKey resultPrivateKey = getAccessTokenRetrieverServiceInstance().getPrivateKeyFromPEM(pemFileReader) + + privateKeyPemFile.deleteOnExit() + + assertThat(resultPrivateKey, notNullValue()) + assertThat(resultPrivateKey.getAlgorithm(), is("RSA")) + assertThat(resultPrivateKey.getFormat(), is("PKCS#8")) + } + + @Test + void testParsePrivateKey() { + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + PrivateKey parsedPrivateKey = getAccessTokenRetrieverServiceInstance().parsePrivateKey(privateKeyPemFile.path) + + privateKeyPemFile.deleteOnExit() + + assertThat(parsedPrivateKey, notNullValue()) + assertThat(parsedPrivateKey.getAlgorithm(), is("RSA")) + assertThat(parsedPrivateKey.getFormat(), is("PKCS#8")) + } + + @Test + void testCreateSignedJWT() { + def clientConfig = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(privateKeyPemFile.path) + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + String signedJwt = getAccessTokenRetrieverServiceInstance(clientConfig).createSignedJWT() + + privateKeyPemFile.deleteOnExit() + + assertThat(signedJwt, notNullValue()) + + // decode the signed jwt and verify + Claims claims = Jwts.parser() + .setSigningKey(generatedPrivateKey) + .parseClaimsJws(signedJwt).getBody() + + assertThat(claims, notNullValue()) + + assertEquals(claims.get("aud"), clientConfig.getBaseUrl() + "/oauth2/v1/token") + assertThat(claims.get("iat"), notNullValue()) + assertThat(claims.get("exp"), notNullValue()) + assertEquals(Integer.valueOf(claims.get("exp")) - Integer.valueOf(claims.get("iat")), 3600, + "token expiry time is not 3600s") + assertThat(claims.get("iss"), notNullValue()) + assertEquals(claims.get("iss"), clientConfig.getClientId(), "iss must be equal to client id") + assertThat(claims.get("sub"), notNullValue()) + assertEquals(claims.get("sub"), clientConfig.getClientId(), "sub must be equal to client id") + assertThat(claims.get("jti"), notNullValue()) + } + + @Test(expectedExceptions = OAuth2TokenRetrieverException.class) + void testGetOAuth2TokenRetrieverRuntimeException() { + def tokenClient = mock(OAuth2TokenClient) + def requestBuilder = mock(RequestBuilder) + def clientConfig = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(privateKeyPemFile.path) + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + when(tokenClient.http()).thenReturn(requestBuilder) + + when(requestBuilder.addHeaderParameter(anyString(), anyString())).thenReturn(requestBuilder) + when(requestBuilder.addQueryParameter(anyString(), anyString())).thenReturn(requestBuilder) + + when(requestBuilder.post(anyString(), any())).thenThrow(new RuntimeException("Unexpected runtime error")) + + def accessTokenRetrieverService = new AccessTokenRetrieverServiceImpl(clientConfig, tokenClient) + accessTokenRetrieverService.getOAuth2AccessToken() + + verify(tokenClient, times(1)).http() + verify(requestBuilder, times(1)).post() + } + + @Test(expectedExceptions = OAuth2HttpException.class) + void testGetOAuth2ResourceException() { + def tokenClient = mock(OAuth2TokenClient) + def requestBuilder = mock(RequestBuilder) + def clientConfig = mock(ClientConfiguration) + + PrivateKey generatedPrivateKey = generatePrivateKey("RSA", 2048) + File privateKeyPemFile = writePrivateKeyToPemFile(generatedPrivateKey, "privateKey") + + String baseUrl = "https://sample.okta.com" + BaseUrlResolver baseUrlResolver = new BaseUrlResolver() { + @Override + String getBaseUrl() { + return baseUrl + } + } + + when(clientConfig.getBaseUrl()).thenReturn(baseUrl) + when(clientConfig.getClientId()).thenReturn("client12345") + when(clientConfig.getPrivateKey()).thenReturn(privateKeyPemFile.path) + when(clientConfig.getBaseUrlResolver()).thenReturn(baseUrlResolver) + when(clientConfig.getClientCredentialsResolver()).thenReturn( + new DefaultClientCredentialsResolver({ -> Optional.empty() })) + + when(tokenClient.http()).thenReturn(requestBuilder) + + when(requestBuilder.addHeaderParameter(anyString(), anyString())).thenReturn(requestBuilder) + when(requestBuilder.addQueryParameter(anyString(), anyString())).thenReturn(requestBuilder) + + DefaultError defaultError = new DefaultError() + defaultError.setStatus(401) + defaultError.setProperty(OAuth2AccessToken.ERROR_KEY, "error key") + defaultError.setProperty(OAuth2AccessToken.ERROR_DESCRIPTION, "error desc") + + ResourceException resourceException = new ResourceException(defaultError) + + when(requestBuilder.post(anyString(), any())).thenThrow(resourceException) + + def accessTokenRetrieverService = new AccessTokenRetrieverServiceImpl(clientConfig, tokenClient) + accessTokenRetrieverService.getOAuth2AccessToken() + + verify(tokenClient, times(1)).http() + verify(requestBuilder, times(1)).post() + } + + // helper methods + + PrivateKey generatePrivateKey(String algorithm, int keySize) { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(algorithm) + keyPairGenerator.initialize(keySize) + KeyPair keyPair = keyPairGenerator.generateKeyPair() + PrivateKey privateKey = keyPair.getPrivate() + return privateKey + } + + File writePrivateKeyToPemFile(PrivateKey privateKey, String fileNamePrefix) { + String encodedString = "-----BEGIN PRIVATE KEY-----\n" + encodedString = encodedString + Base64.getEncoder().encodeToString(privateKey.getEncoded()) + "\n" + encodedString = encodedString + "-----END PRIVATE KEY-----\n" + File privateKeyPemFile = File.createTempFile(fileNamePrefix,".pem") + privateKeyPemFile.write(encodedString) + return privateKeyPemFile + } + + AccessTokenRetrieverServiceImpl getAccessTokenRetrieverServiceInstance(ClientConfiguration clientConfiguration) { + if (clientConfiguration == null) { + ClientConfiguration cc = new ClientConfiguration() + cc.setBaseUrlResolver(new BaseUrlResolver() { + @Override + String getBaseUrl() { + return "https://sample.okta.com" + } + }) + cc.setClientCredentialsResolver(new DefaultClientCredentialsResolver({ -> Optional.empty() })) + return new AccessTokenRetrieverServiceImpl(cc) + } + + return new AccessTokenRetrieverServiceImpl(clientConfiguration) + } + +} \ No newline at end of file diff --git a/integration-tests/src/test/groovy/com/okta/sdk/tests/it/ApplicationsIT.groovy b/integration-tests/src/test/groovy/com/okta/sdk/tests/it/ApplicationsIT.groovy index a549a82c3d7..1fabcf81279 100644 --- a/integration-tests/src/test/groovy/com/okta/sdk/tests/it/ApplicationsIT.groovy +++ b/integration-tests/src/test/groovy/com/okta/sdk/tests/it/ApplicationsIT.groovy @@ -1,5 +1,5 @@ /* - * Copyright 2017 Okta + * Copyright 2017-Present Okta, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ package com.okta.sdk.tests.it + import com.okta.sdk.client.Client import com.okta.sdk.resource.ResourceException import com.okta.sdk.resource.application.AppUser diff --git a/integration-tests/src/test/groovy/com/okta/sdk/tests/it/util/ClientProvider.groovy b/integration-tests/src/test/groovy/com/okta/sdk/tests/it/util/ClientProvider.groovy index a5d23cfd2d0..07c15d8ba54 100644 --- a/integration-tests/src/test/groovy/com/okta/sdk/tests/it/util/ClientProvider.groovy +++ b/integration-tests/src/test/groovy/com/okta/sdk/tests/it/util/ClientProvider.groovy @@ -35,7 +35,6 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import org.testng.IHookCallBack import org.testng.IHookable -import org.testng.ITestNGMethod import org.testng.ITestResult import org.testng.annotations.AfterMethod import org.testng.annotations.Listeners diff --git a/pom.xml b/pom.xml index bdf50214863..43cf8cf4e1c 100644 --- a/pom.xml +++ b/pom.xml @@ -37,7 +37,8 @@ 1.5.8 2.10.0 1.25 - 0.10.7 + 1.64 + 0.11.1 1.5.4 1.2.2 @@ -115,6 +116,11 @@ okta-http-httpclient ${otka.commons.version} + + com.okta.commons + okta-http-okhttp + ${otka.commons.version} + javax.annotation javax.annotation-api @@ -152,17 +158,38 @@ 1.2.2 + + + org.bouncycastle + bcprov-jdk15on + ${bouncycastle.version} + + + + org.bouncycastle + bcpkix-jdk15on + ${bouncycastle.version} + + + io.jsonwebtoken - jjwt + jjwt-api ${jjwt.version} - - - com.fasterxml.jackson.core - jackson-databind - - + + io.jsonwebtoken + jjwt-impl + ${jjwt.version} + runtime + + + io.jsonwebtoken + jjwt-jackson + ${jjwt.version} + runtime + + org.yaml snakeyaml diff --git a/src/findbugs/findbugs-exclude.xml b/src/findbugs/findbugs-exclude.xml index 2d2deaef2d8..e3b5c10cb0c 100644 --- a/src/findbugs/findbugs-exclude.xml +++ b/src/findbugs/findbugs-exclude.xml @@ -1,5 +1,5 @@