Skip to content

Refactored logic for determining SSL inputs #1607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import com.marklogic.client.DatabaseClientBuilder;
import com.marklogic.client.DatabaseClientFactory;
import com.marklogic.client.extra.okhttpclient.RemoveAcceptEncodingConfigurator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import javax.net.ssl.X509TrustManager;
Expand All @@ -40,7 +38,6 @@
*/
public class DatabaseClientPropertySource {

private static final Logger logger = LoggerFactory.getLogger(DatabaseClientPropertySource.class);
private static final String PREFIX = DatabaseClientBuilder.PREFIX;

private final Function<String, Object> propertySource;
Expand Down Expand Up @@ -97,7 +94,7 @@ public class DatabaseClientPropertySource {
if (value instanceof Boolean && Boolean.TRUE.equals(value)) {
disableGzippedResponses = true;
} else if (value instanceof String) {
disableGzippedResponses = Boolean.parseBoolean((String)value);
disableGzippedResponses = Boolean.parseBoolean((String) value);
}
if (disableGzippedResponses) {
DatabaseClientFactory.addConfigurator(new RemoveAcceptEncodingConfigurator());
Expand Down Expand Up @@ -152,20 +149,13 @@ private DatabaseClientFactory.SecurityContext newSecurityContext() {
if (typeValue == null || !(typeValue instanceof String)) {
throw new IllegalArgumentException("Security context should be set, or auth type must be of type String");
}
final String authType = (String)typeValue;
final SSLInputs sslInputs = buildSSLInputs(authType);
final String authType = (String) typeValue;

final SSLInputs sslInputs = buildSSLInputs(authType);
DatabaseClientFactory.SecurityContext securityContext = newSecurityContext(authType, sslInputs);

X509TrustManager trustManager = determineTrustManager(sslInputs);
SSLContext sslContext = sslInputs.getSslContext() != null ?
sslInputs.getSslContext() :
determineSSLContext(sslInputs, trustManager);

if (sslContext != null) {
securityContext.withSSLContext(sslContext, trustManager);
if (sslInputs.getSslContext() != null) {
securityContext.withSSLContext(sslInputs.getSslContext(), sslInputs.getTrustManager());
}

securityContext.withSSLHostnameVerifier(determineHostnameVerifier());
return securityContext;
}
Expand Down Expand Up @@ -202,7 +192,7 @@ private String getNullableStringValue(String propertyName) {
if (value != null && !(value instanceof String)) {
throw new IllegalArgumentException(propertyName + " must be of type String");
}
return (String)value;
return (String) value;
}

private DatabaseClientFactory.SecurityContext newBasicAuthContext() {
Expand Down Expand Up @@ -255,57 +245,6 @@ private DatabaseClientFactory.SecurityContext newSAMLAuthContext() {
return new DatabaseClientFactory.SAMLAuthContext(getRequiredStringValue("saml.token"));
}

private SSLContext determineSSLContext(SSLInputs sslInputs, X509TrustManager trustManager) {
String protocol = sslInputs.getSslProtocol();
if (protocol != null) {
if ("default".equalsIgnoreCase(protocol)) {
try {
return SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Unable to obtain default SSLContext; cause: " + e.getMessage(), e);
}
}

SSLContext sslContext;
try {
sslContext = SSLContext.getInstance(protocol);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Unable to get SSLContext instance with protocol: " + protocol
+ "; cause: " + e.getMessage(), e);
}
// Note that if only a protocol is specified, and not a TrustManager, an attempt will later be made
// to use the JVM's default TrustManager
if (trustManager != null) {
try {
sslContext.init(null, new X509TrustManager[]{trustManager}, null);
} catch (KeyManagementException e) {
throw new RuntimeException("Unable to initialize SSLContext; protocol: " + protocol + "; cause: " + e.getMessage(), e);
}
}
return sslContext;
}
return null;
}

private X509TrustManager determineTrustManager(SSLInputs sslInputs) {
if (sslInputs.getTrustManager() != null) {
return sslInputs.getTrustManager();
}
// If the user chooses the "default" SSLContext, then it's already been initialized - but OkHttp still
// needs a separate X509TrustManager, so use the JVM's default trust manager. The assumption is that the
// default SSLContext was initialized with the JVM's default trust manager. A user can of course always override
// this by simply providing their own trust manager.
if ("default".equalsIgnoreCase(sslInputs.getSslProtocol())) {
X509TrustManager defaultTrustManager = SSLUtil.getDefaultTrustManager();
if (logger.isDebugEnabled() && defaultTrustManager != null && defaultTrustManager.getAcceptedIssuers() != null) {
logger.debug("Count of accepted issuers in default trust manager: {}",
defaultTrustManager.getAcceptedIssuers().length);
}
return defaultTrustManager;
}
return null;
}

private DatabaseClientFactory.SSLHostnameVerifier determineHostnameVerifier() {
Object verifierObject = propertySource.apply(PREFIX + "sslHostnameVerifier");
if (verifierObject instanceof DatabaseClientFactory.SSLHostnameVerifier) {
Expand All @@ -329,61 +268,124 @@ private DatabaseClientFactory.SSLHostnameVerifier determineHostnameVerifier() {
* X509TrustManager.
*
* @param authType used for applying "default" as the SSL protocol for MarkLogic cloud authentication in
* case the user does not define their own SSLContext or SSL protocol
* case the user does not define their own SSLContext or SSL protocol
* @return
*/
private SSLInputs buildSSLInputs(String authType) {
SSLContext sslContext = null;
X509TrustManager userTrustManager = getTrustManager();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where all the changes start. This method should be a lot easier to understand now.


// Approach 1 - user provides an SSLContext object, in which case there's nothing further to check.
SSLContext sslContext = getSSLContext();
if (sslContext != null) {
return new SSLInputs(sslContext, userTrustManager);
}

// Approaches 2 and 3 - user defines an SSL protocol.
// Approach 2 - "default" is a convenience for using the JVM's default SSLContext.
// Approach 3 - create a new SSLContext, and initialize it if the user-provided TrustManager is not null.
final String sslProtocol = getSSLProtocol(authType);
if (sslProtocol != null) {
return "default".equalsIgnoreCase(sslProtocol) ?
useDefaultSSLContext(userTrustManager) :
useNewSSLContext(sslProtocol, userTrustManager);
}

// Approach 4 - no SSL connection is needed.
return new SSLInputs(null, null);
}

private X509TrustManager getTrustManager() {
Object val = propertySource.apply(PREFIX + "trustManager");
if (val != null) {
if (val instanceof X509TrustManager) {
return (X509TrustManager) val;
} else {
throw new IllegalArgumentException("Trust manager must be an instanceof " + X509TrustManager.class.getName());
}
}
return null;
}

private SSLContext getSSLContext() {
Object val = propertySource.apply(PREFIX + "sslContext");
if (val != null) {
if (val instanceof SSLContext) {
sslContext = (SSLContext) val;
return (SSLContext) val;
} else {
throw new IllegalArgumentException("SSL context must be an instanceof " + SSLContext.class.getName());
}
}
return null;
}

private String getSSLProtocol(String authType) {
String sslProtocol = getNullableStringValue("sslProtocol");
if (sslContext == null &&
(sslProtocol == null || sslProtocol.trim().length() == 0) &&
DatabaseClientBuilder.AUTH_TYPE_MARKLOGIC_CLOUD.equalsIgnoreCase(authType)) {
if (sslProtocol != null) {
sslProtocol = sslProtocol.trim();
}
// For convenience for MarkLogic Cloud users, assume the JVM's default SSLContext should trust the certificate
// used by MarkLogic Cloud. A user can always override this default behavior by providing their own SSLContext.
if ((sslProtocol == null || sslProtocol.length() == 0) && DatabaseClientBuilder.AUTH_TYPE_MARKLOGIC_CLOUD.equalsIgnoreCase(authType)) {
sslProtocol = "default";
}
return sslProtocol;
}

/**
* Uses the JVM's default SSLContext. Because OkHttp requires a separate TrustManager, this approach will either
* user the user-provided TrustManager or it will assume that the JVM's default TrustManager should be used.
*/
private SSLInputs useDefaultSSLContext(X509TrustManager userTrustManager) {
SSLContext sslContext;
try {
sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Unable to obtain default SSLContext; cause: " + e.getMessage(), e);
}
X509TrustManager trustManager = userTrustManager != null ? userTrustManager : SSLUtil.getDefaultTrustManager();
return new SSLInputs(sslContext, trustManager);
}

val = propertySource.apply(PREFIX + "trustManager");
X509TrustManager trustManager = null;
if (val != null) {
if (val instanceof X509TrustManager) {
trustManager = (X509TrustManager) val;
} else {
throw new IllegalArgumentException("Trust manager must be an instanceof " + X509TrustManager.class.getName());
/**
* Constructs a new SSLContext based on the given protocol (e.g. TLSv1.2). The SSLContext will be initialized if
* the user's TrustManager is not null. Otherwise, OkHttpUtil will eventually initialize the SSLContext using the
* JVM's default TrustManager.
*/
private SSLInputs useNewSSLContext(String sslProtocol, X509TrustManager userTrustManager) {
SSLContext sslContext;
try {
sslContext = SSLContext.getInstance(sslProtocol);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(String.format("Unable to get SSLContext instance with protocol: %s; cause: %s",
sslProtocol, e.getMessage()), e);
}
if (userTrustManager != null) {
try {
sslContext.init(null, new X509TrustManager[]{userTrustManager}, null);
} catch (KeyManagementException e) {
throw new RuntimeException(String.format("Unable to initialize SSLContext; protocol: %s; cause: %s",
sslProtocol, e.getMessage()), e);
}
}
return new SSLInputs(sslContext, sslProtocol, trustManager);
return new SSLInputs(sslContext, userTrustManager);
}

/**
* Captures the inputs provided by the caller that pertain to constructing an SSLContext.
*/
private static class SSLInputs {
private final SSLContext sslContext;
private final String sslProtocol;
private final X509TrustManager trustManager;

public SSLInputs(SSLContext sslContext, String sslProtocol, X509TrustManager trustManager) {
public SSLInputs(SSLContext sslContext, X509TrustManager trustManager) {
this.sslContext = sslContext;
this.sslProtocol = sslProtocol;
this.trustManager = trustManager;
}

public SSLContext getSslContext() {
return sslContext;
}

public String getSslProtocol() {
return sslProtocol;
}

public X509TrustManager getTrustManager() {
return trustManager;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package com.marklogic.client.impl;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
Expand All @@ -25,7 +28,12 @@
public interface SSLUtil {

static X509TrustManager getDefaultTrustManager() {
return (X509TrustManager) getDefaultTrustManagers()[0];
X509TrustManager trustManager = (X509TrustManager) getDefaultTrustManagers()[0];
Logger logger = LoggerFactory.getLogger(SSLUtil.class);
if (logger.isDebugEnabled() && trustManager.getAcceptedIssuers() != null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this logging here as it's useful to see for debugging purposes regardless of when this method is called.

logger.debug("Count of accepted issuers in default trust manager: {}", trustManager.getAcceptedIssuers().length);
}
return trustManager;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.marklogic.client.test;
package com.marklogic.client.test.ssl;

import com.marklogic.client.DatabaseClient;
import com.marklogic.client.DatabaseClientFactory;
import com.marklogic.client.ForbiddenUserException;
import com.marklogic.client.MarkLogicIOException;
import com.marklogic.client.test.Common;
import com.marklogic.client.test.junit5.RequireSSLExtension;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.marklogic.client.test;
package com.marklogic.client.test.ssl;

import com.marklogic.client.DatabaseClient;
import com.marklogic.client.DatabaseClientFactory.SSLHostnameVerifier;
import com.marklogic.client.MarkLogicIOException;
import com.marklogic.client.document.TextDocumentManager;
import com.marklogic.client.io.StringHandle;
import com.marklogic.client.test.Common;
import org.junit.jupiter.api.Test;

import javax.net.ssl.*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.marklogic.client.test;
package com.marklogic.client.test.ssl;

import com.fasterxml.jackson.databind.node.ObjectNode;
import com.marklogic.client.DatabaseClient;
Expand All @@ -8,6 +8,7 @@
import com.marklogic.client.document.DocumentDescriptor;
import com.marklogic.client.eval.EvalResultIterator;
import com.marklogic.client.io.StringHandle;
import com.marklogic.client.test.Common;
import com.marklogic.client.test.junit5.RequireSSLExtension;
import com.marklogic.mgmt.ManageClient;
import com.marklogic.mgmt.resource.appservers.ServerManager;
Expand Down