From f4bc1cbbb112be1e5ba2df04c179f1a45a9928e9 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 29 Apr 2024 18:28:47 -0600 Subject: [PATCH] Implement OIDC SASL mechanism (#1134) * Implement OIDC SASL mechanism in sync (#1107) JAVA-4980 * Implement OIDC auth for async (#1131) JAVA-4981 * Remove non-machine workflow (#1259) JAVA-5077 * Add Human OIDC Workflow (#1316) JAVA-5328 * OIDC Add remaining environments (azure, gcp), evergreen testing, API naming updates (#1371) JAVA-5353 JAVA-5395 JAVA-4834 JAVA-4932 Co-authored-by: Valentin Kovalenko --- .evergreen/.evg.yml | 152 ++- .evergreen/prepare-oidc-get-tokens-docker.sh | 50 + .evergreen/prepare-oidc-server-docker.sh | 50 + .evergreen/run-mongodb-oidc-test.sh | 40 + .../src/test/unit/util/ThreadTestHelpers.java | 12 +- .../com/mongodb/AuthenticationMechanism.java | 7 + .../main/com/mongodb/ConnectionString.java | 64 +- .../src/main/com/mongodb/MongoCredential.java | 270 +++- .../com/mongodb/assertions/Assertions.java | 36 - .../src/main/com/mongodb/internal/Locks.java | 16 + .../authentication/AzureCredentialHelper.java | 71 +- .../authentication/CredentialInfo.java | 44 + .../authentication/GcpCredentialHelper.java | 13 + .../internal/connection/Authenticator.java | 18 + .../internal/connection/AwsAuthenticator.java | 54 +- .../connection/InternalConnection.java | 2 +- .../connection/InternalStreamConnection.java | 118 +- .../InternalStreamConnectionFactory.java | 21 +- .../InternalStreamConnectionInitializer.java | 20 +- .../connection/MongoCredentialWithCache.java | 27 +- .../connection/OidcAuthenticator.java | 745 +++++++++++ .../connection/SaslAuthenticator.java | 66 +- .../connection/ScramShaAuthenticator.java | 48 +- .../com/mongodb/client/TestHelper.java | 47 + .../connection/TestCommandListener.java | 45 +- .../auth/{ => legacy}/connection-string.json | 187 +++ .../auth/mongodb-oidc-no-retry.json | 421 +++++++ .../com/mongodb/AuthConnectionStringTest.java | 45 +- .../ConnectionStringSpecification.groovy | 2 +- .../com/mongodb/ConnectionStringUnitTest.java | 47 + .../OidcAuthenticationAsyncProseTests.java | 68 + .../com/mongodb/client/unified/Entities.java | 59 +- .../mongodb/client/unified/ErrorMatcher.java | 18 +- .../unified/RunOnRequirementsMatcher.java | 14 +- .../client/unified/UnifiedAuthTest.java | 39 + .../mongodb/client/unified/UnifiedTest.java | 4 +- .../OidcAuthenticationProseTests.java | 1120 +++++++++++++++++ 37 files changed, 3825 insertions(+), 235 deletions(-) create mode 100755 .evergreen/prepare-oidc-get-tokens-docker.sh create mode 100755 .evergreen/prepare-oidc-server-docker.sh create mode 100755 .evergreen/run-mongodb-oidc-test.sh create mode 100644 driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java create mode 100644 driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java create mode 100644 driver-core/src/test/functional/com/mongodb/client/TestHelper.java rename driver-core/src/test/resources/auth/{ => legacy}/connection-string.json (67%) create mode 100644 driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java diff --git a/.evergreen/.evg.yml b/.evergreen/.evg.yml index d35c01fd89f..886282b77c4 100644 --- a/.evergreen/.evg.yml +++ b/.evergreen/.evg.yml @@ -12,9 +12,8 @@ stepback: true # Actual testing tasks are marked with `type: test` command_type: system -# Protect ourself against rogue test case, or curl gone wild, that runs forever -# 12 minutes is the longest we'll ever run -exec_timeout_secs: 3600 # 12 minutes is the longest we'll ever run +# Protect ourselves against rogue test case, or curl gone wild, that runs forever +exec_timeout_secs: 3600 # What to do when evergreen hits the timeout (`post:` tasks are run automatically) timeout: @@ -968,6 +967,60 @@ tasks: - func: "run load-balancer" - func: "run load-balancer tests" + - name: "oidc-auth-test" + commands: + - command: subprocess.exec + type: test + params: + working_dir: "src" + binary: bash + include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + env: + OIDC_ENV: "test" + args: + - .evergreen/run-mongodb-oidc-test.sh + + - name: "oidc-auth-test-azure" + commands: + - command: shell.exec + params: + shell: bash + env: + JAVA_HOME: ${JAVA_HOME} + script: |- + set -o errexit + ${PREPARE_SHELL} + cd src + git add . + git commit -m "add files" + # uncompressed tar used to allow appending .git folder + export AZUREOIDC_DRIVERS_TAR_FILE=/tmp/mongo-java-driver.tar + git archive -o $AZUREOIDC_DRIVERS_TAR_FILE HEAD + tar -rf $AZUREOIDC_DRIVERS_TAR_FILE .git + export AZUREOIDC_TEST_CMD="OIDC_ENV=azure ./.evergreen/run-mongodb-oidc-test.sh" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/run-driver-test.sh + + - name: "oidc-auth-test-gcp" + commands: + - command: shell.exec + params: + shell: bash + script: |- + set -o errexit + ${PREPARE_SHELL} + cd src + git add . + git commit -m "add files" + # uncompressed tar used to allow appending .git folder + export GCPOIDC_DRIVERS_TAR_FILE=/tmp/mongo-java-driver.tar + git archive -o $GCPOIDC_DRIVERS_TAR_FILE HEAD + tar -rf $GCPOIDC_DRIVERS_TAR_FILE .git + # Define the command to run on the VM. + # Ensure that we source the environment file created for us, set up any other variables we need, + # and then run our test suite on the vm. + export GCPOIDC_TEST_CMD="OIDC_ENV=gcp ./.evergreen/run-mongodb-oidc-test.sh" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/gcp/run-driver-test.sh + - name: serverless-test commands: - func: "run serverless" @@ -2065,6 +2118,78 @@ task_groups: tasks: - test-aws-lambda-deployed + - name: testoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + - command: subprocess.exec + params: + binary: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test + + - name: testazureoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - command: subprocess.exec + params: + binary: bash + env: + AZUREOIDC_VMNAME_PREFIX: "JAVA_DRIVER" + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/create-and-setup-vm.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/delete-vm.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-azure + + - name: testgcpoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - command: subprocess.exec + params: + binary: bash + env: + GCPOIDC_VMNAME_PREFIX: "JAVA_DRIVER" + GCPKMS_MACHINETYPE: "e2-medium" # comparable elapsed time to Azure; default was starved, caused timeouts + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-gcp + buildvariants: # Test packaging and other release related routines @@ -2216,6 +2341,27 @@ buildvariants: tasks: - name: "test_atlas_task_group_search_indexes" +- name: "oidc-auth-test" + display_name: "OIDC Auth" + run_on: ubuntu2204-small + tasks: + - name: testoidc_task_group + batchtime: 20160 # 14 days + +- name: testazureoidc-variant + display_name: "OIDC Auth Azure" + run_on: ubuntu2204-small + tasks: + - name: testazureoidc_task_group + batchtime: 20160 # 14 days + +- name: testgcpoidc-variant + display_name: "OIDC Auth GCP" + run_on: ubuntu2204-small + tasks: + - name: testgcpoidc_task_group + batchtime: 20160 # 14 days + - matrix_name: "aws-auth-test" matrix_spec: { ssl: "nossl", jdk: ["jdk8", "jdk17", "jdk21"], version: ["4.4", "5.0", "6.0", "7.0", "latest"], os: "ubuntu", aws-credential-provider: "*" } diff --git a/.evergreen/prepare-oidc-get-tokens-docker.sh b/.evergreen/prepare-oidc-get-tokens-docker.sh new file mode 100755 index 00000000000..e904d5d2b89 --- /dev/null +++ b/.evergreen/prepare-oidc-get-tokens-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* Required OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated OIDC AWS tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./oidc_get_tokens.sh \ No newline at end of file diff --git a/.evergreen/prepare-oidc-server-docker.sh b/.evergreen/prepare-oidc-server-docker.sh new file mode 100755 index 00000000000..0fcd1ed4194 --- /dev/null +++ b/.evergreen/prepare-oidc-server-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./start_local_server.sh \ No newline at end of file diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh new file mode 100755 index 00000000000..1f5c1b310cc --- /dev/null +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set +x # Disable debug trace +set -eu + +echo "Running MONGODB-OIDC authentication tests" +echo "OIDC_ENV $OIDC_ENV" + +if [ $OIDC_ENV == "test" ]; then + if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi + source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh + # java will not need to be installed, but we need to config + RELATIVE_DIR_PATH="$(dirname "${BASH_SOURCE:-$0}")" + source "${RELATIVE_DIR_PATH}/javaConfig.bash" +elif [ $OIDC_ENV == "azure" ]; then + source ./env.sh +elif [ $OIDC_ENV == "gcp" ]; then + source ./secrets-export.sh +else + echo "Unrecognized OIDC_ENV $OIDC_ENV" + exit 1 +fi + + +if ! which java ; then + echo "Installing java..." + sudo apt install openjdk-17-jdk -y + echo "Installed java." +fi + +which java +export OIDC_TESTS_ENABLED=true + +./gradlew -Dorg.mongodb.test.uri="$MONGODB_URI" \ + --stacktrace --debug --info --no-build-cache driver-core:cleanTest \ + driver-sync:test --tests OidcAuthenticationProseTests --tests UnifiedAuthTest \ + driver-reactive-streams:test --tests OidcAuthenticationAsyncProseTests \ diff --git a/bson/src/test/unit/util/ThreadTestHelpers.java b/bson/src/test/unit/util/ThreadTestHelpers.java index a4767c503f9..e2115da079f 100644 --- a/bson/src/test/unit/util/ThreadTestHelpers.java +++ b/bson/src/test/unit/util/ThreadTestHelpers.java @@ -31,15 +31,19 @@ private ThreadTestHelpers() { } public static void executeAll(final int nThreads, final Runnable c) { + executeAll(Collections.nCopies(nThreads, c).toArray(new Runnable[0])); + } + + public static void executeAll(final Runnable... runnables) { ExecutorService service = null; try { - service = Executors.newFixedThreadPool(nThreads); - CountDownLatch latch = new CountDownLatch(nThreads); + service = Executors.newFixedThreadPool(runnables.length); + CountDownLatch latch = new CountDownLatch(runnables.length); List failures = Collections.synchronizedList(new ArrayList<>()); - for (int i = 0; i < nThreads; i++) { + for (final Runnable runnable : runnables) { service.submit(() -> { try { - c.run(); + runnable.run(); } catch (Throwable e) { failures.add(e); } finally { diff --git a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java index db8a909b79d..7a7b7415ef6 100644 --- a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java +++ b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java @@ -37,6 +37,13 @@ public enum AuthenticationMechanism { */ MONGODB_AWS("MONGODB-AWS"), + /** + * The MONGODB-OIDC mechanism. + * @since 4.10 + * @mongodb.server.release 7.0 + */ + MONGODB_OIDC("MONGODB-OIDC"), + /** * The MongoDB X.509 mechanism. This mechanism is available only with client certificates over SSL. */ diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index e715b8983f6..34378d4069f 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -38,6 +38,7 @@ import java.net.URLDecoder; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -47,7 +48,11 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -225,9 +230,9 @@ * *

Authentication configuration:

*
    - *
  • {@code authMechanism=MONGO-CR|GSSAPI|PLAIN|MONGODB-X509}: The authentication mechanism to use if a credential was supplied. + *
  • {@code authMechanism=MONGO-CR|GSSAPI|PLAIN|MONGODB-X509|MONGODB-OIDC}: The authentication mechanism to use if a credential was supplied. * The default is unspecified, in which case the client will pick the most secure mechanism available based on the sever version. For the - * GSSAPI and MONGODB-X509 mechanisms, no password is accepted, only the username. + * GSSAPI, MONGODB-X509, and MONGODB-OIDC mechanisms, no password is accepted, only the username. *
  • *
  • {@code authSource=string}: The source of the authentication credentials. This is typically the database that * the credentials have been created. The value defaults to the database specified in the path portion of the connection string. @@ -235,7 +240,9 @@ * mechanism (the default). *
  • *
  • {@code authMechanismProperties=PROPERTY_NAME:PROPERTY_VALUE,PROPERTY_NAME2:PROPERTY_VALUE2}: This option allows authentication - * mechanism properties to be set on the connection string. + * mechanism properties to be set on the connection string. Property values must be percent-encoded individually, when + * special characters are used, including {@code ,} (comma), {@code =}, {@code +}, {@code &}, and {@code %}. The + * entire substring following the {@code =} should not itself be encoded. *
  • *
  • {@code gssapiServiceName=string}: This option only applies to the GSSAPI mechanism and is used to alter the service name. * Deprecated, please use {@code authMechanismProperties=SERVICE_NAME:string} instead. @@ -281,6 +288,9 @@ public class ConnectionString { private static final Set ALLOWED_OPTIONS_IN_TXT_RECORD = new HashSet<>(asList("authsource", "replicaset", "loadbalanced")); private static final Logger LOGGER = Loggers.getLogger("uri"); + private static final List MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING = Stream.of(ALLOWED_HOSTS_KEY) + .map(k -> k.toLowerCase()) + .collect(Collectors.toList()); private final MongoCredential credential; private final boolean isSrvProtocol; @@ -909,13 +919,21 @@ private MongoCredential createCredentials(final Map> option if (credential != null && authMechanismProperties != null) { for (String part : authMechanismProperties.split(",")) { - String[] mechanismPropertyKeyValue = part.split(":"); + String[] mechanismPropertyKeyValue = part.split(":", 2); if (mechanismPropertyKeyValue.length != 2) { throw new IllegalArgumentException(format("The connection string contains invalid authentication properties. " + "'%s' is not a key value pair", part)); } String key = mechanismPropertyKeyValue[0].trim().toLowerCase(); String value = mechanismPropertyKeyValue[1].trim(); + if (decodeValueOfKeyValuePair(credential.getMechanism())) { + value = urldecode(value); + } + if (MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING.contains(key)) { + throw new IllegalArgumentException(format("The connection string contains disallowed mechanism properties. " + + "'%s' must be set on the credential programmatically.", key)); + } + if (key.equals("canonicalize_host_name")) { credential = credential.withMechanismProperty(key, Boolean.valueOf(value)); } else { @@ -926,6 +944,27 @@ private MongoCredential createCredentials(final Map> option return credential; } + private static boolean decodeWholeOptionValue(final boolean isOidc, final String key) { + // The "whole option value" is the entire string following = in an option, + // including separators when the value is a list or list of key-values. + // This is the original parsing behaviour, but implies that users can + // encode separators (much like they might with URL parameters). This + // behaviour implies that users cannot encode "key-value" values that + // contain a comma, because this will (after this "whole value decoding) + // be parsed as a key-value separator, rather than part of a value. + return !(isOidc && key.equals("authmechanismproperties")); + } + + private static boolean decodeValueOfKeyValuePair(@Nullable final String mechanismName) { + // Only authMechanismProperties should be individually decoded, and only + // when the mechanism is OIDC. These will not have been decoded. + return AuthenticationMechanism.MONGODB_OIDC.getMechanismName().equals(mechanismName); + } + + private static boolean isOidc(final List options) { + return options.contains("authMechanism=" + AuthenticationMechanism.MONGODB_OIDC.getMechanismName()); + } + private MongoCredential createMongoCredentialWithMechanism(final AuthenticationMechanism mechanism, final String userName, @Nullable final char[] password, @Nullable final String authSource, @@ -975,6 +1014,10 @@ private MongoCredential createMongoCredentialWithMechanism(final AuthenticationM case MONGODB_AWS: credential = MongoCredential.createAwsCredential(userName, password); break; + case MONGODB_OIDC: + validateCreateOidcCredential(password); + credential = MongoCredential.createOidcCredential(userName); + break; default: throw new UnsupportedOperationException(format("The connection string contains an invalid authentication mechanism'. " + "'%s' is not a supported authentication mechanism", @@ -1002,12 +1045,14 @@ private String getLastValue(final Map> optionsMap, final St private Map> parseOptions(final String optionsPart) { Map> optionsMap = new HashMap<>(); - if (optionsPart.length() == 0) { + if (optionsPart.isEmpty()) { return optionsMap; } - for (final String part : optionsPart.split("&|;")) { - if (part.length() == 0) { + List options = Arrays.asList(optionsPart.split("&|;")); + boolean isOidc = isOidc(options); + for (final String part : options) { + if (part.isEmpty()) { continue; } int idx = part.indexOf("="); @@ -1018,7 +1063,10 @@ private Map> parseOptions(final String optionsPart) { if (valueList == null) { valueList = new ArrayList<>(1); } - valueList.add(urldecode(value)); + if (decodeWholeOptionValue(isOidc, key)) { + value = urldecode(value); + } + valueList.add(value); optionsMap.put(key, valueList); } else { throw new IllegalArgumentException(format("The connection string contains an invalid option '%s'. " diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index ffa2a3c4e02..e085ac074f0 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -17,22 +17,28 @@ package com.mongodb; import com.mongodb.annotations.Beta; +import com.mongodb.annotations.Evolving; import com.mongodb.annotations.Immutable; import com.mongodb.lang.Nullable; +import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import static com.mongodb.AuthenticationMechanism.GSSAPI; import static com.mongodb.AuthenticationMechanism.MONGODB_AWS; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.AuthenticationMechanism.MONGODB_X509; import static com.mongodb.AuthenticationMechanism.PLAIN; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_256; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateOidcCredentialConstruction; /** * Represents credentials to authenticate to a mongo server,as well as the source of the credentials and the authentication mechanism to @@ -179,6 +185,94 @@ public final class MongoCredential { @Beta(Beta.Reason.CLIENT) public static final String AWS_CREDENTIAL_PROVIDER_KEY = "AWS_CREDENTIAL_PROVIDER"; + /** + * Mechanism property key for specifying the environment for OIDC, which is + * the name of a built-in OIDC application environment integration to use + * to obtain credentials. The value must be either "gcp" or "azure". + * This is an alternative to supplying a callback. + *

    + * The "gcp" and "azure" environments require + * {@link MongoCredential#TOKEN_RESOURCE_KEY} to be specified. + *

    + * If this is provided, + * {@link MongoCredential#OIDC_CALLBACK_KEY} and + * {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @see MongoCredential#TOKEN_RESOURCE_KEY + * @since 5.1 + */ + public static final String ENVIRONMENT_KEY = "ENVIRONMENT"; + + /** + * Mechanism property key for the OIDC callback. + * This callback is invoked when the OIDC-based authenticator requests + * a token. The type of the value must be {@link OidcCallback}. + * {@link IdpInfo} will not be supplied to the callback, + * and a {@linkplain com.mongodb.MongoCredential.OidcCallbackResult#getRefreshToken() refresh token} + * must not be returned by the callback. + *

    + * If this is provided, {@link MongoCredential#ENVIRONMENT_KEY} + * and {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 5.1 + */ + public static final String OIDC_CALLBACK_KEY = "OIDC_CALLBACK"; + + /** + * Mechanism property key for the OIDC human callback. + * This callback is invoked when the OIDC-based authenticator requests + * a token from the identity provider (IDP) using the IDP information + * from the MongoDB server. The type of the value must be + * {@link OidcCallback}. + *

    + * If this is provided, {@link MongoCredential#ENVIRONMENT_KEY} + * and {@link MongoCredential#OIDC_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 5.1 + */ + public static final String OIDC_HUMAN_CALLBACK_KEY = "OIDC_HUMAN_CALLBACK"; + + + /** + * Mechanism property key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. + * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. + * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts + * the driver will raise an error. The type of the value must be {@code List}. + * + * @see MongoCredential#DEFAULT_ALLOWED_HOSTS + * @see #createOidcCredential(String) + * @since 5.1 + */ + public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; + + /** + * The list of allowed hosts that will be used if no + * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * The default allowed hosts are: + * {@code "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} + * + * @see #createOidcCredential(String) + * @since 5.1 + */ + public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( + "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + + /** + * Mechanism property key for specifying he URI of the target resource (sometimes called the audience), + * used in some OIDC environments. + * + * @see MongoCredential#ENVIRONMENT_KEY + * @see #createOidcCredential(String) + * @since 5.1 + */ + public static final String TOKEN_RESOURCE_KEY = "TOKEN_RESOURCE"; + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -327,6 +421,24 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam return new MongoCredential(MONGODB_AWS, userName, "$external", password); } + /** + * Creates a MongoCredential instance for the MONGODB-OIDC mechanism. + * + * @param userName the user name, which may be null. This is the OIDC principal name. + * @return the credential + * @since 5.1 + * @see #withMechanismProperty(String, Object) + * @see #ENVIRONMENT_KEY + * @see #TOKEN_RESOURCE_KEY + * @see #OIDC_CALLBACK_KEY + * @see #OIDC_HUMAN_CALLBACK_KEY + * @see #ALLOWED_HOSTS_KEY + * @mongodb.server.release 7.0 + */ + public static MongoCredential createOidcCredential(@Nullable final String userName) { + return new MongoCredential(MONGODB_OIDC, userName, "$external", null); + } + /** * Creates a new MongoCredential as a copy of this instance, with the specified mechanism property added. * @@ -370,7 +482,12 @@ public MongoCredential withMechanism(final AuthenticationMechanism mechanism) { MongoCredential(@Nullable final AuthenticationMechanism mechanism, @Nullable final String userName, final String source, @Nullable final char[] password, final Map mechanismProperties) { - if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS).contains(mechanism)) { + if (mechanism == MONGODB_OIDC) { + validateOidcCredentialConstruction(source, mechanismProperties); + validateCreateOidcCredential(password); + } + + if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS, MONGODB_OIDC).contains(mechanism)) { throw new IllegalArgumentException("username can not be null"); } @@ -543,4 +660,155 @@ public String toString() { + ", mechanismProperties=" + '}'; } + + /** + * The context for the {@link OidcCallback#onRequest(OidcCallbackContext) OIDC request callback}. + * + * @since 5.1 + */ + @Evolving + public interface OidcCallbackContext { + /** + * @return Convenience method to obtain the {@linkplain MongoCredential#getUserName() username}. + */ + @Nullable + String getUserName(); + + /** + * @return The timeout that this callback must complete within. + */ + Duration getTimeout(); + + /** + * @return The OIDC callback API version. Currently, version 1. + */ + int getVersion(); + + /** + * @return The OIDC Identity Provider's configuration that can be used + * to acquire an Access Token, or null if not using a + * {@linkplain MongoCredential#OIDC_HUMAN_CALLBACK_KEY human callback.} + */ + @Nullable + IdpInfo getIdpInfo(); + + /** + * @return The OIDC Refresh token supplied by a prior callback invocation, + * or null if no token was supplied, or if not using a + * {@linkplain MongoCredential#OIDC_HUMAN_CALLBACK_KEY human callback.} + */ + @Nullable + String getRefreshToken(); + } + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. + *

    + * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + * + * @since 5.1 + */ + public interface OidcCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + OidcCallbackResult onRequest(OidcCallbackContext context); + } + + /** + * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + * + * @since 5.1 + */ + @Evolving + public interface IdpInfo { + /** + * @return URL which describes the Authorization Server. This identifier is the + * iss of provided access tokens, and is viable for RFC8414 metadata + * discovery and RFC9207 identification. + */ + String getIssuer(); + + /** + * @return Unique client ID for this OIDC client. + */ + @Nullable + String getClientId(); + + /** + * @return Additional scopes to request from Identity Provider. Immutable. + */ + List getRequestScopes(); + } + + /** + * The OIDC credential information. + * + * @since 5.1 + */ + public static final class OidcCallbackResult { + + private final String accessToken; + + private final Duration expiresIn; + + @Nullable + private final String refreshToken; + + + /** + * An access token that does not expire. + * @param accessToken The OIDC access token. + */ + public OidcCallbackResult(final String accessToken) { + this(accessToken, Duration.ZERO, null); + } + + /** + * @param accessToken The OIDC access token. + * @param expiresIn Time until the access token expires. + * A {@linkplain Duration#isZero() zero-length} duration + * means that the access token does not expire. + */ + public OidcCallbackResult(final String accessToken, final Duration expiresIn) { + this(accessToken, expiresIn, null); + } + + /** + * @param accessToken The OIDC access token. + * @param expiresIn Time until the access token expires. + * A {@linkplain Duration#isZero() zero-length} duration + * means that the access token does not expire. + * @param refreshToken The refresh token. If null, refresh will not be attempted. + */ + public OidcCallbackResult(final String accessToken, final Duration expiresIn, + @Nullable final String refreshToken) { + notNull("accessToken", accessToken); + notNull("expiresIn", expiresIn); + if (expiresIn.isNegative()) { + throw new IllegalArgumentException("expiresIn must not be a negative value"); + } + this.accessToken = accessToken; + this.expiresIn = expiresIn; + this.refreshToken = refreshToken; + } + + /** + * @return The OIDC access token. + */ + public String getAccessToken() { + return accessToken; + } + + /** + * @return The OIDC refresh token. If null, refresh will not be attempted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } + } } diff --git a/driver-core/src/main/com/mongodb/assertions/Assertions.java b/driver-core/src/main/com/mongodb/assertions/Assertions.java index ae30c179e85..9866c222c6d 100644 --- a/driver-core/src/main/com/mongodb/assertions/Assertions.java +++ b/driver-core/src/main/com/mongodb/assertions/Assertions.java @@ -17,7 +17,6 @@ package com.mongodb.assertions; -import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import java.util.Collection; @@ -79,25 +78,6 @@ public static Iterable notNullElements(final String name, final Iterable< return values; } - /** - * Throw IllegalArgumentException if the value is null. - * - * @param name the parameter name - * @param value the value that should not be null - * @param callback the callback that also is passed the exception if the value is null - * @param the value type - * @return the value - * @throws java.lang.IllegalArgumentException if value is null - */ - public static T notNull(final String name, final T value, final SingleResultCallback callback) { - if (value == null) { - IllegalArgumentException exception = new IllegalArgumentException(name + " can not be null"); - callback.completeExceptionally(exception); - throw exception; - } - return value; - } - /** * Throw IllegalStateException if the condition if false. * @@ -111,22 +91,6 @@ public static void isTrue(final String name, final boolean condition) { } } - /** - * Throw IllegalStateException if the condition if false. - * - * @param name the name of the state that is being checked - * @param condition the condition about the parameter to check - * @param callback the callback that also is passed the exception if the condition is not true - * @throws java.lang.IllegalStateException if the condition is false - */ - public static void isTrue(final String name, final boolean condition, final SingleResultCallback callback) { - if (!condition) { - IllegalStateException exception = new IllegalStateException("state should be: " + name); - callback.completeExceptionally(exception); - throw exception; - } - } - /** * Throw IllegalArgumentException if the condition if false. * diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 51583bfd56f..984de156f27 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -20,6 +20,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.StampedLock; import java.util.function.Supplier; import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; @@ -35,6 +36,21 @@ public static void withLock(final Lock lock, final Runnable action) { }); } + public static void withInterruptibleLock(final StampedLock lock, final Runnable runnable) throws MongoInterruptedException{ + long stamp; + try { + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new MongoInterruptedException("Interrupted waiting for lock", e); + } + try { + runnable.run(); + } finally { + lock.unlockWrite(stamp); + } + } + public static V withLock(final Lock lock, final Supplier supplier) { return checkedWithLock(lock, supplier::get); } diff --git a/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java b/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java index 7c75e397d2a..2a48b8b6fc3 100644 --- a/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java +++ b/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java @@ -18,10 +18,13 @@ import com.mongodb.MongoClientException; import com.mongodb.internal.ExpirableValue; +import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; import org.bson.json.JsonParseException; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; import java.time.Duration; import java.util.HashMap; import java.util.Map; @@ -55,33 +58,11 @@ public static BsonDocument obtainFromEnvironment() { if (cachedValue.isPresent()) { accessToken = cachedValue.get(); } else { - String endpoint = "http://" + "169.254.169.254:80" - + "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net"; - - Map headers = new HashMap<>(); - headers.put("Metadata", "true"); - headers.put("Accept", "application/json"); - long startNanoTime = System.nanoTime(); - BsonDocument responseDocument; - try { - responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers)); - } catch (JsonParseException e) { - throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e); - } - - if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) { - throw new MongoClientException(String.format( - "The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD)); - } - if (!responseDocument.isString(EXPIRES_IN_FIELD)) { - throw new MongoClientException(String.format( - "The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD)); - } - accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue(); - int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue()); - cachedAccessToken = ExpirableValue.expirable(accessToken, Duration.ofSeconds(expiresInSeconds).minus(Duration.ofMinutes(1)), - startNanoTime); + CredentialInfo response = fetchAzureCredentialInfo("https://vault.azure.net", null); + accessToken = response.getAccessToken(); + Duration duration = response.getExpiresIn().minus(Duration.ofMinutes(1)); + cachedAccessToken = ExpirableValue.expirable(accessToken, duration, startNanoTime); } } finally { CACHED_ACCESS_TOKEN_LOCK.unlock(); @@ -90,6 +71,44 @@ public static BsonDocument obtainFromEnvironment() { return new BsonDocument("accessToken", new BsonString(accessToken)); } + public static CredentialInfo fetchAzureCredentialInfo(final String resource, @Nullable final String clientId) { + String endpoint = "http://169.254.169.254:80" + + "/metadata/identity/oauth2/token?api-version=2018-02-01" + + "&resource=" + getEncoded(resource) + + (clientId == null ? "" : "&client_id=" + getEncoded(clientId)); + + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + headers.put("Accept", "application/json"); + + BsonDocument responseDocument; + try { + responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers)); + } catch (JsonParseException e) { + throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e); + } + + if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) { + throw new MongoClientException(String.format( + "The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD)); + } + if (!responseDocument.isString(EXPIRES_IN_FIELD)) { + throw new MongoClientException(String.format( + "The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD)); + } + String accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue(); + int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue()); + return new CredentialInfo(accessToken, Duration.ofSeconds(expiresInSeconds)); + } + + static String getEncoded(final String resource) { + try { + return URLEncoder.encode(resource, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + private AzureCredentialHelper() { } } diff --git a/driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java b/driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java new file mode 100644 index 00000000000..8b1e601b13a --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java @@ -0,0 +1,44 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.authentication; + +import java.time.Duration; + +/** + *

    This class is not part of the public API and may be removed or changed at any time

    + */ +public final class CredentialInfo { + private final String accessToken; + private final Duration expiresIn; + + /** + * @param expiresIn The meaning of {@linkplain Duration#isZero() zero-length} duration is the same as in + * {@link com.mongodb.MongoCredential.OidcCallbackResult#OidcCallbackResult(String, Duration)}. + */ + public CredentialInfo(final String accessToken, final Duration expiresIn) { + this.accessToken = accessToken; + this.expiresIn = expiresIn; + } + + public String getAccessToken() { + return accessToken; + } + + public Duration getExpiresIn() { + return expiresIn; + } +} diff --git a/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java b/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java index 92b3fdd6040..3f0272da48c 100644 --- a/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java +++ b/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java @@ -19,9 +19,11 @@ import com.mongodb.MongoClientException; import org.bson.BsonDocument; +import java.time.Duration; import java.util.HashMap; import java.util.Map; +import static com.mongodb.internal.authentication.AzureCredentialHelper.getEncoded; import static com.mongodb.internal.authentication.HttpHelper.getHttpContents; /** @@ -44,6 +46,17 @@ public static BsonDocument obtainFromEnvironment() { } } + public static CredentialInfo fetchGcpCredentialInfo(final String audience) { + String endpoint = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=" + + getEncoded(audience); + Map header = new HashMap<>(); + header.put("Metadata-Flavor", "Google"); + String response = getHttpContents("GET", endpoint, header); + return new CredentialInfo( + response, + Duration.ZERO); + } + private GcpCredentialHelper() { } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 9ec4780d958..232eeb45049 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -21,11 +21,13 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; /** *

    This class is not part of the public API and may be removed or changed at any time

    @@ -42,6 +44,11 @@ public abstract class Authenticator { this.serverApi = serverApi; } + public static boolean shouldAuthenticate(@Nullable final Authenticator authenticator, + final ConnectionDescription connectionDescription) { + return authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER; + } + @NonNull MongoCredentialWithCache getMongoCredentialWithCache() { return credential; @@ -93,4 +100,15 @@ T getNonNullMechanismProperty(final String key, @Nullable final T defaultVal abstract void authenticateAsync(InternalConnection connection, ConnectionDescription connectionDescription, SingleResultCallback callback); + + public void reauthenticate(final InternalConnection connection) { + authenticate(connection, connection.getDescription()); + } + + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + beginAsync().thenRun((c) -> { + authenticateAsync(connection, connection.getDescription(), c); + }).finish(callback); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index ec0fc3f9c8f..35f9f8120ee 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -16,7 +16,6 @@ package com.mongodb.internal.connection; -import com.mongodb.AuthenticationMechanism; import com.mongodb.AwsCredential; import com.mongodb.MongoClientException; import com.mongodb.MongoCredential; @@ -27,14 +26,10 @@ import com.mongodb.internal.authentication.AwsCredentialHelper; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; -import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; import org.bson.RawBsonDocument; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; @@ -77,27 +72,12 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { return new AwsSaslClient(getMongoCredential()); } - private static class AwsSaslClient implements SaslClient { - private final MongoCredential credential; + private static class AwsSaslClient extends SaslClientImpl { private final byte[] clientNonce = new byte[RANDOM_LENGTH]; private int step = -1; AwsSaslClient(final MongoCredential credential) { - this.credential = credential; - } - - @Override - public String getMechanismName() { - AuthenticationMechanism authMechanism = credential.getAuthenticationMechanism(); - if (authMechanism == null) { - throw new IllegalArgumentException("Authentication mechanism cannot be null"); - } - return authMechanism.getMechanismName(); - } - - @Override - public boolean hasInitialResponse() { - return true; + super(credential); } @Override @@ -117,26 +97,6 @@ public boolean isComplete() { return step == 1; } - @Override - public byte[] unwrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public byte[] wrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public Object getNegotiatedProperty(final String s) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { new SecureRandom().nextBytes(this.clientNonce); @@ -184,6 +144,7 @@ private byte[] computeClientFinalMessage(final byte[] serverFirst) throws SaslEx private AwsCredential createAwsCredential() { AwsCredential awsCredential; + MongoCredential credential = getCredential(); if (credential.getUserName() != null) { if (credential.getPassword() == null) { throw new MongoClientException("secretAccessKey is required for AWS credential"); @@ -207,13 +168,4 @@ private AwsCredential createAwsCredential() { return awsCredential; } } - - private static byte[] toBson(final BsonDocument document) { - byte[] bytes; - BasicOutputBuffer buffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); - bytes = new byte[buffer.size()]; - System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); - return bytes; - } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java index 405ef31f5cf..e2b0188572e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java @@ -49,7 +49,7 @@ public interface InternalConnection extends BufferProvider { ServerDescription getInitialServerDescription(); /** - * Opens the connection so its ready for use + * Opens the connection so its ready for use. Will perform a handshake. */ void open(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index dec5a1d1977..218835f083e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -18,6 +18,7 @@ import com.mongodb.LoggerSettings; import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; import com.mongodb.MongoCompressor; import com.mongodb.MongoException; import com.mongodb.MongoInternalException; @@ -41,6 +42,7 @@ import com.mongodb.event.CommandListener; import com.mongodb.internal.ResourceUtil; import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.async.AsyncSupplier; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; @@ -64,11 +66,15 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; +import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO_LOWER; @@ -92,6 +98,19 @@ @NotThreadSafe public class InternalStreamConnection implements InternalConnection { + private static volatile boolean recordEverything = false; + + /** + * Will attempt to record events to the command listener that are usually + * suppressed. + * + * @param recordEverything whether to attempt to record everything + */ + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void setRecordEverything(final boolean recordEverything) { + InternalStreamConnection.recordEverything = recordEverything; + } + private static final Set SECURITY_SENSITIVE_COMMANDS = new HashSet<>(asList( "authenticate", "saslStart", @@ -111,6 +130,8 @@ public class InternalStreamConnection implements InternalConnection { private static final Logger LOGGER = Loggers.getLogger("connection"); private final ClusterConnectionMode clusterConnectionMode; + @Nullable + private final Authenticator authenticator; private final boolean isMonitoringConnection; private final ServerId serverId; private final ConnectionGenerationSupplier connectionGenerationSupplier; @@ -122,6 +143,7 @@ public class InternalStreamConnection implements InternalConnection { private final AtomicBoolean isClosed = new AtomicBoolean(); private final AtomicBoolean opened = new AtomicBoolean(); + private final AtomicBoolean authenticated = new AtomicBoolean(); private final List compressorList; private final LoggerSettings loggerSettings; @@ -147,17 +169,20 @@ public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMod final ConnectionGenerationSupplier connectionGenerationSupplier, final StreamFactory streamFactory, final List compressorList, final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer) { - this(clusterConnectionMode, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, + this(clusterConnectionMode, null, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, LoggerSettings.builder().build(), commandListener, connectionInitializer); } - public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, final boolean isMonitoringConnection, + public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, + @Nullable final Authenticator authenticator, + final boolean isMonitoringConnection, final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier, final StreamFactory streamFactory, final List compressorList, final LoggerSettings loggerSettings, final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer) { this.clusterConnectionMode = clusterConnectionMode; + this.authenticator = authenticator; this.isMonitoringConnection = isMonitoringConnection; this.serverId = notNull("serverId", serverId); this.connectionGenerationSupplier = notNull("connectionGeneration", connectionGenerationSupplier); @@ -217,7 +242,7 @@ public void open() { @Override public void openAsync(final SingleResultCallback callback) { - isTrue("Open already called", stream == null, callback); + assertNull(stream); try { stream = streamFactory.create(serverId.getAddress()); stream.openAsync(new AsyncCompletionHandler() { @@ -271,6 +296,7 @@ private void initAfterHandshakeFinish(final InternalConnectionInitializationDesc description = initializationDescription.getConnectionDescription(); initialServerDescription = initializationDescription.getServerDescription(); opened.set(true); + authenticated.set(true); sendCompressor = findSendCompressor(description); } @@ -336,8 +362,66 @@ public boolean isClosed() { @Override public T sendAndReceive(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext) { - CommandEventSender commandEventSender; + Supplier sendAndReceiveInternal = () -> sendAndReceiveInternal( + message, decoder, sessionContext, requestContext, operationContext); + try { + return sendAndReceiveInternal.get(); + } catch (MongoCommandException e) { + if (reauthenticationIsTriggered(e)) { + return reauthenticateAndRetry(sendAndReceiveInternal); + } + throw e; + } + } + + @Override + public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { + + AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( + message, decoder, sessionContext, requestContext, operationContext, c); + beginAsync().thenSupply(c -> { + sendAndReceiveAsyncInternal.getAsync(c); + }).onErrorIf(e -> reauthenticationIsTriggered(e), (t, c) -> { + reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c); + }).finish(callback); + } + + private T reauthenticateAndRetry(final Supplier operation) { + authenticated.set(false); + assertNotNull(authenticator).reauthenticate(this); + authenticated.set(true); + return operation.get(); + } + + private void reauthenticateAndRetryAsync(final AsyncSupplier operation, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticated.set(false); + assertNotNull(authenticator).reauthenticateAsync(this, c); + }).thenSupply((c) -> { + authenticated.set(true); + operation.getAsync(c); + }).finish(callback); + } + + public boolean reauthenticationIsTriggered(@Nullable final Throwable t) { + if (!shouldAuthenticate(authenticator, this.description)) { + return false; + } + if (t instanceof MongoCommandException) { + MongoCommandException e = (MongoCommandException) t; + return e.getErrorCode() == 391; + } + return false; + } + + @Nullable + private T sendAndReceiveInternal(final CommandMessage message, final Decoder decoder, + final SessionContext sessionContext, final RequestContext requestContext, + final OperationContext operationContext) { + CommandEventSender commandEventSender; try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { message.encode(bsonOutput, sessionContext); commandEventSender = createCommandEventSender(message, bsonOutput, requestContext, operationContext); @@ -449,14 +533,11 @@ private T receiveCommandMessageResponse(final Decoder decoder, commandEventSender.sendFailedEvent(e); } throw e; - } + } } - @Override - public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + private void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); - if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); return; @@ -567,7 +648,7 @@ public void sendMessage(final List byteBuffers, final int lastRequestId @Override public ResponseBuffers receiveMessage(final int responseTo) { - notNull("stream is open", stream); + assertNotNull(stream); if (isClosed()) { throw new MongoSocketClosedException("Cannot read from a closed stream", getServerAddress()); } @@ -585,8 +666,9 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional } @Override - public void sendMessageAsync(final List byteBuffers, final int lastRequestId, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); + public void sendMessageAsync(final List byteBuffers, final int lastRequestId, + final SingleResultCallback callback) { + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); @@ -618,7 +700,7 @@ public void failed(final Throwable t) { @Override public void receiveMessageAsync(final int responseTo, final SingleResultCallback callback) { - isTrue("stream is open", stream != null, callback); + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); @@ -839,12 +921,14 @@ public void onResult(@Nullable final ByteBuf result, @Nullable final Throwable t private CommandEventSender createCommandEventSender(final CommandMessage message, final ByteBufferBsonOutput bsonOutput, final RequestContext requestContext, final OperationContext operationContext) { - if (!isMonitoringConnection && opened() && (commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()))) { - return new LoggingCommandEventSender(SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, - commandListener, requestContext, operationContext, message, bsonOutput, COMMAND_PROTOCOL_LOGGER, loggerSettings); - } else { + boolean listensOrLogs = commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()); + if (!recordEverything && (isMonitoringConnection || !opened() || !authenticated.get() || !listensOrLogs)) { return new NoOpCommandEventSender(); } + return new LoggingCommandEventSender( + SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, + requestContext, operationContext, message, bsonOutput, + COMMAND_PROTOCOL_LOGGER, loggerSettings); } private ClusterId getClusterId() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java index 6cf2453c187..8b5c840c501 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; import com.mongodb.LoggerSettings; import com.mongodb.MongoCompressor; import com.mongodb.MongoDriverInformation; @@ -28,7 +29,6 @@ import java.util.List; -import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.ClientMetadataHelper.createClientMetadataDocument; @@ -74,18 +74,21 @@ class InternalStreamConnectionFactory implements InternalConnectionFactory { @Override public InternalConnection create(final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier) { Authenticator authenticator = credential == null ? null : createAuthenticator(credential); - return new InternalStreamConnection(clusterConnectionMode, isMonitoringConnection, serverId, connectionGenerationSupplier, + InternalStreamConnectionInitializer connectionInitializer = new InternalStreamConnectionInitializer( + clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, serverApi); + return new InternalStreamConnection( + clusterConnectionMode, authenticator, + isMonitoringConnection, serverId, connectionGenerationSupplier, streamFactory, compressorList, loggerSettings, commandListener, - new InternalStreamConnectionInitializer(clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, - serverApi)); + connectionInitializer); } private Authenticator createAuthenticator(final MongoCredentialWithCache credential) { - if (credential.getAuthenticationMechanism() == null) { + AuthenticationMechanism authenticationMechanism = credential.getAuthenticationMechanism(); + if (authenticationMechanism == null) { return new DefaultAuthenticator(credential, clusterConnectionMode, serverApi); } - - switch (assertNotNull(credential.getAuthenticationMechanism())) { + switch (authenticationMechanism) { case GSSAPI: return new GSSAPIAuthenticator(credential, clusterConnectionMode, serverApi); case PLAIN: @@ -97,8 +100,10 @@ private Authenticator createAuthenticator(final MongoCredentialWithCache credent return new ScramShaAuthenticator(credential, clusterConnectionMode, serverApi); case MONGODB_AWS: return new AwsAuthenticator(credential, clusterConnectionMode, serverApi); + case MONGODB_OIDC: + return new OidcAuthenticator(credential, clusterConnectionMode, serverApi); default: - throw new IllegalArgumentException("Unsupported authentication mechanism " + credential.getAuthenticationMechanism()); + throw new IllegalArgumentException("Unsupported authentication mechanism " + authenticationMechanism); } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index f3d77ff2b2d..d4858f3d973 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -25,7 +25,6 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.connection.ConnectionId; import com.mongodb.connection.ServerDescription; -import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import org.bson.BsonArray; @@ -82,8 +81,10 @@ public InternalConnectionInitializationDescription finishHandshake(final Interna final InternalConnectionInitializationDescription description) { notNull("internalConnection", internalConnection); notNull("description", description); - - authenticate(internalConnection, description.getConnectionDescription()); + final ConnectionDescription connectionDescription = description.getConnectionDescription(); + if (Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { + authenticator.authenticate(internalConnection, connectionDescription); + } return completeConnectionDescriptionInitialization(internalConnection, description); } @@ -106,11 +107,12 @@ public void startHandshakeAsync(final InternalConnection internalConnection, public void finishHandshakeAsync(final InternalConnection internalConnection, final InternalConnectionInitializationDescription description, final SingleResultCallback callback) { - if (authenticator == null || description.getConnectionDescription().getServerType() - == ServerType.REPLICA_SET_ARBITER) { + ConnectionDescription connectionDescription = description.getConnectionDescription(); + + if (!Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { completeConnectionDescriptionInitializationAsync(internalConnection, description, callback); } else { - authenticator.authenticateAsync(internalConnection, description.getConnectionDescription(), + authenticator.authenticateAsync(internalConnection, connectionDescription, (result1, t1) -> { if (t1 != null) { callback.onResult(null, t1); @@ -201,12 +203,6 @@ private InternalConnectionInitializationDescription completeConnectionDescriptio description); } - private void authenticate(final InternalConnection internalConnection, final ConnectionDescription connectionDescription) { - if (authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER) { - authenticator.authenticate(internalConnection, connectionDescription); - } - } - private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult) { if (authenticator instanceof SpeculativeAuthenticator) { ((SpeculativeAuthenticator) authenticator).setSpeculativeAuthenticateResponse( diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 43b9ad3eec5..682637bf9ed 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -22,8 +22,10 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.StampedLock; import static com.mongodb.internal.Locks.withInterruptibleLock; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry; /** *

    This class is not part of the public API and may be removed or changed at any time

    @@ -33,12 +35,12 @@ public class MongoCredentialWithCache { private final Cache cache; public MongoCredentialWithCache(final MongoCredential credential) { - this(credential, null); + this(credential, new Cache()); } - private MongoCredentialWithCache(final MongoCredential credential, @Nullable final Cache cache) { + private MongoCredentialWithCache(final MongoCredential credential, final Cache cache) { this.credential = credential; - this.cache = cache != null ? cache : new Cache(); + this.cache = cache; } public MongoCredentialWithCache withMechanism(final AuthenticationMechanism mechanism) { @@ -63,15 +65,34 @@ public void putInCache(final Object key, final Object value) { cache.set(key, value); } + OidcCacheEntry getOidcCacheEntry() { + return cache.oidcCacheEntry; + } + + void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { + this.cache.oidcCacheEntry = oidcCacheEntry; + } + + StampedLock getOidcLock() { + return cache.oidcLock; + } + public Lock getLock() { return cache.lock; } + /** + * Stores any state associated with the credential. + */ static class Cache { private final ReentrantLock lock = new ReentrantLock(); private Object cacheKey; private Object cacheValue; + + private final StampedLock oidcLock = new StampedLock(); + private volatile OidcCacheEntry oidcCacheEntry = new OidcCacheEntry(); + Object get(final Object key) { return withInterruptibleLock(lock, () -> { if (cacheKey != null && cacheKey.equals(key)) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java new file mode 100644 index 00000000000..af26abbf87f --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -0,0 +1,745 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.OidcCallbackResult; +import com.mongodb.MongoException; +import com.mongodb.MongoSecurityException; +import com.mongodb.ServerAddress; +import com.mongodb.ServerApi; +import com.mongodb.connection.ClusterConnectionMode; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.internal.Locks; +import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.authentication.AzureCredentialHelper; +import com.mongodb.internal.authentication.CredentialInfo; +import com.mongodb.internal.authentication.GcpCredentialHelper; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.RawBsonDocument; + +import javax.security.sasl.SaslClient; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.ENVIRONMENT_KEY; +import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallback; +import static com.mongodb.MongoCredential.OidcCallbackContext; +import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY; +import static com.mongodb.assertions.Assertions.assertFalse; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; +import static java.lang.String.format; + +/** + *

    This class is not part of the public API and may be removed or changed at any time

    + */ +public final class OidcAuthenticator extends SaslAuthenticator { + + private static final String TEST_ENVIRONMENT = "test"; + private static final String AZURE_ENVIRONMENT = "azure"; + private static final String GCP_ENVIRONMENT = "gcp"; + private static final List IMPLEMENTED_ENVIRONMENTS = Arrays.asList( + AZURE_ENVIRONMENT, GCP_ENVIRONMENT, TEST_ENVIRONMENT); + private static final List USER_SUPPORTED_ENVIRONMENTS = Arrays.asList( + AZURE_ENVIRONMENT, GCP_ENVIRONMENT); + private static final List REQUIRES_TOKEN_RESOURCE = Arrays.asList( + AZURE_ENVIRONMENT, GCP_ENVIRONMENT); + private static final List ALLOWS_USERNAME = Arrays.asList( + AZURE_ENVIRONMENT); + + private static final Duration CALLBACK_TIMEOUT = Duration.ofMinutes(5); + + public static final String OIDC_TOKEN_FILE = "OIDC_TOKEN_FILE"; + + private static final int CALLBACK_API_VERSION_NUMBER = 1; + + @Nullable + private ServerAddress serverAddress; + + @Nullable + private String connectionLastAccessToken; + + private FallbackState fallbackState = FallbackState.INITIAL; + + @Nullable + private BsonDocument speculativeAuthenticateResponse; + + public OidcAuthenticator(final MongoCredentialWithCache credential, + final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { + super(credential, clusterConnectionMode, serverApi); + validateBeforeUse(credential.getCredential()); + + if (getMongoCredential().getAuthenticationMechanism() != MONGODB_OIDC) { + throw new MongoException("Incorrect mechanism: " + getMongoCredential().getMechanism()); + } + } + + @Override + public String getMechanismName() { + return MONGODB_OIDC.getMechanismName(); + } + + @Override + protected SaslClient createSaslClient(final ServerAddress serverAddress) { + this.serverAddress = assertNotNull(serverAddress); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + return new OidcSaslClient(mongoCredentialWithCache); + } + + @Override + @Nullable + public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { + try { + String cachedAccessToken = getMongoCredentialWithCache() + .getOidcCacheEntry() + .getCachedAccessToken(); + if (cachedAccessToken != null) { + return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); + } else { + // otherwise, skip speculative auth + return null; + } + } catch (Exception e) { + throw wrapException(e); + } + } + + private BsonDocument wrapInSpeculative(final byte[] outToken) { + BsonDocument startDocument = createSaslStartCommandDocument(outToken) + .append("db", new BsonString(getMongoCredential().getSource())); + appendSaslStartOptions(startDocument); + return startDocument; + } + + @Override + @Nullable + public BsonDocument getSpeculativeAuthenticateResponse() { + BsonDocument response = speculativeAuthenticateResponse; + // response should only be read once + this.speculativeAuthenticateResponse = null; + if (response == null) { + this.connectionLastAccessToken = null; + } + return response; + } + + @Override + public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument response) { + speculativeAuthenticateResponse = response; + } + + private boolean isHumanCallback() { + // built-in providers (aws, azure...) are considered machine callbacks + return getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY) != null; + } + + @Nullable + private OidcCallback getOidcCallbackMechanismProperty(final String key) { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(key, null); + } + + private OidcCallback getRequestCallback() { + String environment = getMongoCredential().getMechanismProperty(ENVIRONMENT_KEY, null); + OidcCallback machine; + if (TEST_ENVIRONMENT.equals(environment)) { + machine = getTestCallback(); + } else if (AZURE_ENVIRONMENT.equals(environment)) { + machine = getAzureCallback(getMongoCredential()); + } else if (GCP_ENVIRONMENT.equals(environment)) { + machine = getGcpCallback(getMongoCredential()); + } else { + machine = getOidcCallbackMechanismProperty(OIDC_CALLBACK_KEY); + } + OidcCallback human = getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY); + return machine != null ? machine : assertNotNull(human); + } + + private static OidcCallback getTestCallback() { + return (context) -> { + String accessToken = readTokenFromFile(); + return new OidcCallbackResult(accessToken); + }; + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static OidcCallback getAzureCallback(final MongoCredential credential) { + return (context) -> { + String resource = assertNotNull(credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null)); + String clientId = credential.getUserName(); + CredentialInfo response = AzureCredentialHelper.fetchAzureCredentialInfo(resource, clientId); + return new OidcCallbackResult(response.getAccessToken(), response.getExpiresIn()); + }; + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static OidcCallback getGcpCallback(final MongoCredential credential) { + return (context) -> { + String resource = assertNotNull(credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null)); + CredentialInfo response = GcpCredentialHelper.fetchGcpCredentialInfo(resource); + return new OidcCallbackResult(response.getAccessToken(), response.getExpiresIn()); + }; + } + + @Override + public void reauthenticate(final InternalConnection connection) { + assertTrue(connection.opened()); + authenticationLoop(connection, connection.getDescription()); + } + + @Override + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + assertTrue(connection.opened()); + authenticationLoopAsync(connection, connection.getDescription(), c); + }).finish(callback); + } + + @Override + public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { + assertFalse(connection.opened()); + authenticationLoop(connection, connectionDescription); + } + + @Override + void authenticateAsync(final InternalConnection connection, final ConnectionDescription connectionDescription, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + assertFalse(connection.opened()); + authenticationLoopAsync(connection, connectionDescription, c); + }).finish(callback); + } + + private static boolean triggersRetry(@Nullable final Throwable t) { + if (t instanceof MongoSecurityException) { + MongoSecurityException e = (MongoSecurityException) t; + Throwable cause = e.getCause(); + if (cause instanceof MongoCommandException) { + MongoCommandException commandCause = (MongoCommandException) cause; + return commandCause.getErrorCode() == 18; + } + } + return false; + } + + private void authenticationLoop(final InternalConnection connection, final ConnectionDescription description) { + fallbackState = FallbackState.INITIAL; + while (true) { + try { + super.authenticate(connection, description); + break; + } catch (Exception e) { + if (triggersRetry(e) && shouldRetryHandler()) { + continue; + } + throw e; + } + } + } + + private void authenticationLoopAsync(final InternalConnection connection, final ConnectionDescription description, + final SingleResultCallback callback) { + fallbackState = FallbackState.INITIAL; + beginAsync().thenRunRetryingWhile( + c -> super.authenticateAsync(connection, description, c), + e -> triggersRetry(e) && shouldRetryHandler() + ).finish(callback); + } + + private byte[] evaluate(final byte[] challenge) { + byte[][] jwt = new byte[1][]; + Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> { + OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); + String cachedRefreshToken = oidcCacheEntry.getRefreshToken(); + IdpInfo cachedIdpInfo = oidcCacheEntry.getIdpInfo(); + String cachedAccessToken = validatedCachedAccessToken(); + OidcCallback requestCallback = getRequestCallback(); + boolean isHuman = isHumanCallback(); + String userName = getMongoCredentialWithCache().getCredential().getUserName(); + + if (cachedAccessToken != null) { + fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + jwt[0] = prepareTokenAsJwt(cachedAccessToken); + } else if (cachedRefreshToken != null) { + // cached refresh token is only set when isHuman + // original IDP info will be present, if refresh token present + assertNotNull(cachedIdpInfo); + // Invoke Callback using cached Refresh Token + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, cachedIdpInfo, cachedRefreshToken, userName)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); + } else { + // cache is empty + + if (!isHuman) { + // no principal request + fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, userName)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result); + if (result.getRefreshToken() != null) { + throw new MongoConfigurationException( + "Refresh token must only be provided in human workflow"); + } + } else { + /* + A check for present idp info short-circuits phase-3a. + If a challenge is present, it can only be a response to a + "principal-request", so the challenge must be the resulting + idp info. Such a request is made during speculative auth, + though the source is unimportant, as long as we detect and + use it here. + */ + boolean idpInfoNotPresent = challenge.length == 0; + /* + Checking that the fallback state is not phase-3a ensures that + this does not loop infinitely in the case of a bug. + */ + boolean alreadyTriedPrincipal = fallbackState == FallbackState.PHASE_3A_PRINCIPAL; + if (!alreadyTriedPrincipal && idpInfoNotPresent) { + // request for idp info, only in the human workflow + fallbackState = FallbackState.PHASE_3A_PRINCIPAL; + jwt[0] = prepareUsername(userName); + } else { + IdpInfo idpInfo = toIdpInfo(challenge); + // there is no cached refresh token + fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, idpInfo, null, userName)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); + } + } + } + }); + return jwt[0]; + } + + /** + * Must be guarded by {@link MongoCredentialWithCache#getOidcLock()}. + */ + @Nullable + private String validatedCachedAccessToken() { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + String cachedAccessToken = cacheEntry.getCachedAccessToken(); + String invalidConnectionAccessToken = connectionLastAccessToken; + + if (cachedAccessToken != null) { + boolean cachedTokenIsInvalid = cachedAccessToken.equals(invalidConnectionAccessToken); + if (cachedTokenIsInvalid) { + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry.clearAccessToken()); + cachedAccessToken = null; + } + } + return cachedAccessToken; + } + + private boolean clientIsComplete() { + return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; + } + + private boolean shouldRetryHandler() { + boolean[] result = new boolean[1]; + Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { + // a cached access token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken()); + result[0] = true; + } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { + // a refresh token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + result[0] = true; + } else { + // a clean-restart failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + result[0] = false; + } + }); + return result[0]; + } + + static final class OidcCacheEntry { + @Nullable + private final String accessToken; + @Nullable + private final String refreshToken; + @Nullable + private final IdpInfo idpInfo; + + @Override + public String toString() { + return "OidcCacheEntry{" + + "\n accessToken=[omitted]" + + ",\n refreshToken=[omitted]" + + ",\n idpInfo=" + idpInfo + + '}'; + } + + OidcCacheEntry() { + this(null, null, null); + } + + private OidcCacheEntry(@Nullable final String accessToken, + @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { + this.accessToken = accessToken; + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } + + @Nullable + String getCachedAccessToken() { + return accessToken; + } + + @Nullable + String getRefreshToken() { + return refreshToken; + } + + @Nullable + IdpInfo getIdpInfo() { + return idpInfo; + } + + OidcCacheEntry clearAccessToken() { + return new OidcCacheEntry( + null, + this.refreshToken, + this.idpInfo); + } + + OidcCacheEntry clearRefreshToken() { + return new OidcCacheEntry( + this.accessToken, + null, + null); + } + } + + private final class OidcSaslClient extends SaslClientImpl { + + private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) { + super(mongoCredentialWithCache.getCredential()); + } + + @Override + public byte[] evaluateChallenge(final byte[] challenge) { + return evaluate(challenge); + } + + @Override + public boolean isComplete() { + return clientIsComplete(); + } + + } + + private static String readTokenFromFile() { + String path = System.getenv(OIDC_TOKEN_FILE); + if (path == null) { + throw new MongoClientException( + format("Environment variable must be specified: %s", OIDC_TOKEN_FILE)); + } + try { + return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new MongoClientException(format( + "Could not read file specified by environment variable: %s at path: %s", + OIDC_TOKEN_FILE, path), e); + } + } + + private byte[] populateCacheWithCallbackResultAndPrepareJwt( + @Nullable final IdpInfo serverInfo, + @Nullable final OidcCallbackResult oidcCallbackResult) { + if (oidcCallbackResult == null) { + throw new MongoConfigurationException("Result of callback must not be null"); + } + OidcCacheEntry newEntry = new OidcCacheEntry(oidcCallbackResult.getAccessToken(), + oidcCallbackResult.getRefreshToken(), serverInfo); + getMongoCredentialWithCache().setOidcCacheEntry(newEntry); + return prepareTokenAsJwt(oidcCallbackResult.getAccessToken()); + } + + private static byte[] prepareUsername(@Nullable final String username) { + BsonDocument document = new BsonDocument(); + if (username != null) { + document = document.append("n", new BsonString(username)); + } + return toBson(document); + } + + private IdpInfo toIdpInfo(final byte[] challenge) { + // validate here to prevent creating IdpInfo for unauthorized hosts + validateAllowedHosts(getMongoCredential()); + BsonDocument c = new RawBsonDocument(challenge); + String issuer = c.getString("issuer").getValue(); + String clientId = !c.containsKey("clientId") ? null : c.getString("clientId").getValue(); + return new IdpInfoImpl( + issuer, + clientId, + getStringArray(c, "requestScopes")); + } + + @Nullable + private static List getStringArray(final BsonDocument document, final String key) { + if (!document.isArray(key)) { + return null; + } + return document.getArray(key).stream() + // ignore non-string values from server, rather than error + .filter(v -> v.isString()) + .map(v -> v.asString().getValue()) + .collect(Collectors.toList()); + } + + private void validateAllowedHosts(final MongoCredential credential) { + List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); + String host = assertNotNull(serverAddress).getHost(); + boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { + if (allowedHost.startsWith("*.")) { + String ending = allowedHost.substring(1); + return host.endsWith(ending); + } else if (allowedHost.contains("*")) { + throw new IllegalArgumentException( + "Allowed host " + allowedHost + " contains invalid wildcard"); + } else { + return host.equals(allowedHost); + } + }); + if (!permitted) { + throw new MongoSecurityException( + credential, "Host " + host + " not permitted by " + ALLOWED_HOSTS_KEY + + ", values: " + allowedHosts); + } + } + + private byte[] prepareTokenAsJwt(final String accessToken) { + connectionLastAccessToken = accessToken; + return toJwtDocument(accessToken); + } + + private static byte[] toJwtDocument(final String accessToken) { + return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); + } + + /** + * Contains all validation logic for OIDC in one location + */ + public static final class OidcValidator { + private OidcValidator() { + } + + public static void validateOidcCredentialConstruction( + final String source, + final Map mechanismProperties) { + + if (!"$external".equals(source)) { + throw new IllegalArgumentException("source must be '$external'"); + } + + Object environmentName = mechanismProperties.get(ENVIRONMENT_KEY.toLowerCase()); + if (environmentName != null) { + if (!(environmentName instanceof String) || !IMPLEMENTED_ENVIRONMENTS.contains(environmentName)) { + throw new IllegalArgumentException(ENVIRONMENT_KEY + " must be one of: " + USER_SUPPORTED_ENVIRONMENTS); + } + } + } + + public static void validateCreateOidcCredential(@Nullable final char[] password) { + if (password != null) { + throw new IllegalArgumentException("password must not be specified for " + + AuthenticationMechanism.MONGODB_OIDC); + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void validateBeforeUse(final MongoCredential credential) { + String userName = credential.getUserName(); + Object environmentName = credential.getMechanismProperty(ENVIRONMENT_KEY, null); + Object machineCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); + Object humanCallback = credential.getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, null); + if (environmentName == null) { + // callback + if (machineCallback == null && humanCallback == null) { + throw new IllegalArgumentException("Either " + ENVIRONMENT_KEY + + " or " + OIDC_CALLBACK_KEY + + " or " + OIDC_HUMAN_CALLBACK_KEY + + " must be specified"); + } + if (machineCallback != null && humanCallback != null) { + throw new IllegalArgumentException("Both " + OIDC_CALLBACK_KEY + + " and " + OIDC_HUMAN_CALLBACK_KEY + + " must not be specified"); + } + } else { + if (!(environmentName instanceof String)) { + throw new IllegalArgumentException(ENVIRONMENT_KEY + " must be a String"); + } + if (userName != null && !ALLOWS_USERNAME.contains(environmentName)) { + throw new IllegalArgumentException("user name must not be specified when " + ENVIRONMENT_KEY + " is specified"); + } + if (machineCallback != null) { + throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + ENVIRONMENT_KEY + " is specified"); + } + if (humanCallback != null) { + throw new IllegalArgumentException(OIDC_HUMAN_CALLBACK_KEY + " must not be specified when " + ENVIRONMENT_KEY + " is specified"); + } + String tokenResource = credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null); + boolean hasTokenResourceProperty = tokenResource != null; + boolean tokenResourceSupported = REQUIRES_TOKEN_RESOURCE.contains(environmentName); + if (hasTokenResourceProperty != tokenResourceSupported) { + throw new IllegalArgumentException(TOKEN_RESOURCE_KEY + + " must be provided if and only if " + ENVIRONMENT_KEY + + " " + environmentName + " " + + " is one of: " + REQUIRES_TOKEN_RESOURCE + + ". " + TOKEN_RESOURCE_KEY + ": " + tokenResource); + } + } + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static class OidcCallbackContextImpl implements OidcCallbackContext { + private final Duration timeout; + @Nullable + private final IdpInfo idpInfo; + @Nullable + private final String refreshToken; + @Nullable + private final String userName; + + OidcCallbackContextImpl(final Duration timeout, @Nullable final String userName) { + this.timeout = assertNotNull(timeout); + this.idpInfo = null; + this.refreshToken = null; + this.userName = userName; + } + + OidcCallbackContextImpl(final Duration timeout, final IdpInfo idpInfo, + @Nullable final String refreshToken, @Nullable final String userName) { + this.timeout = assertNotNull(timeout); + this.idpInfo = assertNotNull(idpInfo); + this.refreshToken = refreshToken; + this.userName = userName; + } + + @Override + @Nullable + public IdpInfo getIdpInfo() { + return idpInfo; + } + + @Override + public Duration getTimeout() { + return timeout; + } + + @Override + public int getVersion() { + return CALLBACK_API_VERSION_NUMBER; + } + + @Override + @Nullable + public String getRefreshToken() { + return refreshToken; + } + + @Override + @Nullable + public String getUserName() { + return userName; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class IdpInfoImpl implements IdpInfo { + private final String issuer; + @Nullable + private final String clientId; + private final List requestScopes; + + IdpInfoImpl(final String issuer, @Nullable final String clientId, @Nullable final List requestScopes) { + this.issuer = assertNotNull(issuer); + this.clientId = clientId; + this.requestScopes = requestScopes == null + ? Collections.emptyList() + : Collections.unmodifiableList(requestScopes); + } + + @Override + public String getIssuer() { + return issuer; + } + + @Override + @Nullable + public String getClientId() { + return clientId; + } + + @Override + public List getRequestScopes() { + return requestScopes; + } + } + + /** + * What was sent in the last request by this connection to the server. + */ + private enum FallbackState { + INITIAL, + PHASE_1_CACHED_TOKEN, + PHASE_2_REFRESH_CALLBACK_TOKEN, + PHASE_3A_PRINCIPAL, + PHASE_3B_CALLBACK_TOKEN + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index 2c2321fcbad..6e4bea55514 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -16,6 +16,8 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoCredential; import com.mongodb.MongoException; import com.mongodb.MongoInterruptedException; import com.mongodb.MongoSecurityException; @@ -30,9 +32,13 @@ import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; +import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; import javax.security.auth.Subject; import javax.security.auth.login.LoginException; @@ -55,6 +61,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut super(credential, clusterConnectionMode, serverApi); } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { doAsSubject(() -> { SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress()); @@ -121,7 +128,7 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - BsonDocument response = getSpeculativeAuthenticateResponse(); + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); if (response != null) { return response; } @@ -136,9 +143,9 @@ private BsonDocument getNextSaslResponse(final SaslClient saslClient, final Inte private void getNextSaslResponseAsync(final SaslClient saslClient, final InternalConnection connection, final SingleResultCallback callback) { - BsonDocument response = getSpeculativeAuthenticateResponse(); SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { @@ -280,6 +287,15 @@ void doAsSubject(final java.security.PrivilegedAction action) { } } + static byte[] toBson(final BsonDocument document) { + byte[] bytes; + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); + bytes = new byte[buffer.size()]; + System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); + return bytes; + } + private final class Continuator implements SingleResultCallback { private final SaslClient saslClient; private final BsonDocument saslStartDocument; @@ -331,7 +347,51 @@ private void continueConversation(final BsonDocument result) { disposeOfSaslClient(saslClient); } } - } + protected abstract static class SaslClientImpl implements SaslClient { + private final MongoCredential credential; + + protected SaslClientImpl(final MongoCredential credential) { + this.credential = credential; + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] unwrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public byte[] wrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public Object getNegotiatedProperty(final String s) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public void dispose() { + // nothing to do + } + + @Override + public final String getMechanismName() { + AuthenticationMechanism authMechanism = getCredential().getAuthenticationMechanism(); + if (authMechanism == null) { + throw new IllegalArgumentException("Authentication mechanism cannot be null"); + } + return authMechanism.getMechanismName(); + } + + protected final MongoCredential getCredential() { + return credential; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java index 5dec0d90c1e..02bc7912c93 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java @@ -92,7 +92,7 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { if (speculativeSaslClient != null) { return speculativeSaslClient; } - return new ScramShaSaslClient(getMongoCredentialWithCache(), randomStringGenerator, authenticationHashGenerator); + return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator); } @Override @@ -122,9 +122,7 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp } } - class ScramShaSaslClient implements SaslClient { - - private final MongoCredentialWithCache credential; + class ScramShaSaslClient extends SaslClientImpl { private final RandomStringGenerator randomStringGenerator; private final AuthenticationHashGenerator authenticationHashGenerator; private final String hAlgorithm; @@ -136,9 +134,11 @@ class ScramShaSaslClient implements SaslClient { private byte[] serverSignature; private int step = -1; - ScramShaSaslClient(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator, - final AuthenticationHashGenerator authenticationHashGenerator) { - this.credential = credential; + ScramShaSaslClient( + final MongoCredential credential, + final RandomStringGenerator randomStringGenerator, + final AuthenticationHashGenerator authenticationHashGenerator) { + super(credential); this.randomStringGenerator = randomStringGenerator; this.authenticationHashGenerator = authenticationHashGenerator; if (assertNotNull(credential.getAuthenticationMechanism()).equals(SCRAM_SHA_1)) { @@ -150,14 +150,6 @@ class ScramShaSaslClient implements SaslClient { } } - public String getMechanismName() { - return assertNotNull(credential.getAuthenticationMechanism()).getMechanismName(); - } - - public boolean hasInitialResponse() { - return true; - } - public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { step++; if (step == 0) { @@ -167,7 +159,8 @@ public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { } else if (step == 2) { return validateServerSignature(challenge); } else { - throw new SaslException(format("Too many steps involved in the %s negotiation.", getMechanismName())); + throw new SaslException(format("Too many steps involved in the %s negotiation.", + super.getMechanismName())); } } @@ -184,22 +177,6 @@ public boolean isComplete() { return step == 2; } - public byte[] unwrap(final byte[] incoming, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public byte[] wrap(final byte[] outgoing, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public Object getNegotiatedProperty(final String propName) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { clientNonce = randomStringGenerator.generate(RANDOM_LENGTH); String clientFirstMessage = "n=" + getUserName() + ",r=" + clientNonce; @@ -318,9 +295,8 @@ private HashMap parseServerResponse(final String response) { return map; } - private String getUserName() { - String userName = credential.getCredential().getUserName(); + String userName = getCredential().getUserName(); if (userName == null) { throw new IllegalArgumentException("Username can not be null"); } @@ -328,8 +304,8 @@ private String getUserName() { } private String getAuthenicationHash() { - String password = authenticationHashGenerator.generate(credential.getCredential()); - if (credential.getAuthenticationMechanism() == SCRAM_SHA_256) { + String password = authenticationHashGenerator.generate(getCredential()); + if (getCredential().getAuthenticationMechanism() == SCRAM_SHA_256) { password = SaslPrep.saslPrepStored(password); } return password; diff --git a/driver-core/src/test/functional/com/mongodb/client/TestHelper.java b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java new file mode 100644 index 00000000000..237c03c7e19 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java @@ -0,0 +1,47 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.lang.Nullable; + +import java.lang.reflect.Field; +import java.util.Map; + +import static java.lang.System.getenv; + +public final class TestHelper { + + public static void setEnvironmentVariable(final String name, @Nullable final String value) { + try { + Map env = getenv(); + Field field = env.getClass().getDeclaredField("m"); + field.setAccessible(true); + @SuppressWarnings("unchecked") + Map result = (Map) field.get(env); + if (value == null) { + result.remove(name); + } else { + result.put(name, value); + } + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private TestHelper() { + } +} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index 0a2838c2d55..c8274f382fc 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -17,11 +17,13 @@ package com.mongodb.internal.connection; import com.mongodb.MongoTimeoutException; +import com.mongodb.client.TestListener; import com.mongodb.event.CommandEvent; import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; +import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonDocumentWriter; import org.bson.BsonDouble; @@ -55,6 +57,8 @@ public class TestCommandListener implements CommandListener { private final List eventTypes; private final List ignoredCommandMonitoringEvents; private final List events = new ArrayList<>(); + @Nullable + private final TestListener listener; private final Lock lock = new ReentrantLock(); private final Condition commandCompletedCondition = lock.newCondition(); private final boolean observeSensitiveCommands; @@ -76,25 +80,44 @@ public Codec get(final Class clazz, final CodecRegistry registry) { }); } + /** + * When a test listener is set, this command listener will send string events to the + * test listener in the form {@code " "}, where the event + * type will be lowercase and will omit the terms "command" and "event". + * For example: {@code "saslContinue succeeded"}. + * + * @see InternalStreamConnection#setRecordEverything(boolean) + * @param listener the test listener + */ + public TestCommandListener(final TestListener listener) { + this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList(), true, listener); + } + public TestCommandListener() { this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList()); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents) { - this(eventTypes, ignoredCommandMonitoringEvents, true); + this(eventTypes, ignoredCommandMonitoringEvents, true, null); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents, - final boolean observeSensitiveCommands) { + final boolean observeSensitiveCommands, @Nullable final TestListener listener) { this.eventTypes = eventTypes; this.ignoredCommandMonitoringEvents = ignoredCommandMonitoringEvents; this.observeSensitiveCommands = observeSensitiveCommands; + this.listener = listener; } + + public void reset() { lock.lock(); try { events.clear(); + if (listener != null) { + listener.clear(); + } } finally { lock.unlock(); } @@ -109,6 +132,18 @@ public List getEvents() { } } + private void addEvent(final CommandEvent c) { + events.add(c); + String className = c.getClass().getSimpleName() + .replace("Command", "") + .replace("Event", "") + .toLowerCase(); + // example: "saslContinue succeeded" + if (listener != null) { + listener.add(c.getCommandName() + " " + className); + } + } + public CommandStartedEvent getCommandStartedEvent(final String commandName) { for (CommandEvent event : getCommandStartedEvents()) { if (event instanceof CommandStartedEvent) { @@ -226,7 +261,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getDatabaseName(), event.getCommandName(), event.getCommand() == null ? null : getWritableClone(event.getCommand()))); } finally { @@ -249,7 +284,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getDatabaseName(), event.getCommandName(), event.getResponse() == null ? null : event.getResponse().clone(), event.getElapsedTime(TimeUnit.NANOSECONDS))); @@ -274,7 +309,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(event); + addEvent(event); commandCompletedCondition.signal(); } finally { lock.unlock(); diff --git a/driver-core/src/test/resources/auth/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json similarity index 67% rename from driver-core/src/test/resources/auth/connection-string.json rename to driver-core/src/test/resources/auth/legacy/connection-string.json index 2a37ae8df47..072dd176dc8 100644 --- a/driver-core/src/test/resources/auth/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -444,6 +444,193 @@ "AWS_SESSION_TOKEN": "token!@#$%^&*()_+" } } + }, + { + "description": "should recognise the mechanism with test environment (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:test", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "test" + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with environment (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=ENVIRONMENT:test", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "test" + } + } + }, + { + "description": "should throw an exception if supplied a password (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:test", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if username is specified for test (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&ENVIRONMENT:test", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified environment is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:invalid", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither environment nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "valid": false, + "credential": null + }, + { + "description": "should recognise the mechanism with azure provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:foo", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "foo" + } + } + }, + { + "description": "should accept a username with azure provider (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:foo", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "foo" + } + } + }, + { + "description": "should accept a url-encoded TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:mongodb%3A%2F%2Ftest-cluster", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "mongodb://test-cluster" + } + } + }, + { + "description": "should accept an un-encoded TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:mongodb://test-cluster", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "mongodb://test-cluster" + } + } + }, + { + "description": "should handle a complicated url-encoded TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:abc%2Cd%25ef%3Ag%26hi", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "abc,d%ef:g&hi" + } + } + }, + { + "description": "should url-encode a TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:a$b", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "a$b" + } + } + }, + { + "description": "should accept a username and throw an error for a password with azure provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:foo", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if no token audience is given for azure provider (MONGODB-OIDC)", + "uri": "mongodb://username@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure", + "valid": false, + "credential": null + }, + { + "description": "should recognise the mechanism with gcp provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:foo", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "gcp", + "TOKEN_RESOURCE": "foo" + } + } + }, + { + "description": "should throw an error for a username and password with gcp provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:foo", + "valid": false, + "credential": null + }, + { + "description": "should throw an error if not TOKEN_RESOURCE with gcp provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp", + "valid": false, + "credential": null } ] } diff --git a/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json b/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json new file mode 100644 index 00000000000..83065f492ae --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json @@ -0,0 +1,421 @@ +{ + "description": "MONGODB-OIDC authentication with retry disabled", + "schemaVersion": "1.19", + "runOnRequirements": [ + { + "minServerVersion": "7.0", + "auth": true, + "authMechanism": "MONGODB-OIDC" + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "client": { + "id": "client0", + "uriOptions": { + "authMechanism": "MONGODB-OIDC", + "authMechanismProperties": { + "$$placeholder": 1 + }, + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "test", + "documents": [] + } + ], + "tests": [ + { + "description": "A read operation should succeed", + "operations": [ + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": {} + }, + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "A write operation should succeed", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Read commands should reauthenticate and retry when a ReauthenticationRequired error happens", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": {} + }, + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write commands should reauthenticate and retry when a ReauthenticationRequired error happens", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Handshake with cached token should use speculative authentication", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "closeConnection": true + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + }, + "expectError": { + "isClientError": true + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "saslStart" + ], + "errorCode": 18 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Handshake without cached token should not use speculative authentication", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "saslStart" + ], + "errorCode": 18 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + }, + "expectError": { + "errorCode": 18 + } + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index dfb81ba8de4..cab5b0e0365 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -16,9 +16,13 @@ package com.mongodb; +import com.mongodb.internal.connection.OidcAuthenticator; +import com.mongodb.lang.Nullable; import junit.framework.TestCase; +import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonNull; +import org.bson.BsonString; import org.bson.BsonValue; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,7 +36,10 @@ import java.util.Collection; import java.util.List; -// See https://github.com/mongodb/specifications/tree/master/source/auth/tests +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; + +// See https://github.com/mongodb/specifications/tree/master/source/auth/legacy/tests @RunWith(Parameterized.class) public class AuthConnectionStringTest extends TestCase { private final String input; @@ -56,7 +63,7 @@ public void shouldPassAllOutcomes() { @Parameterized.Parameters(name = "{1}") public static Collection data() throws URISyntaxException, IOException { List data = new ArrayList<>(); - for (File file : JsonPoweredTestHelper.getTestFiles("/auth")) { + for (File file : JsonPoweredTestHelper.getTestFiles("/auth/legacy")) { BsonDocument testDocument = JsonPoweredTestHelper.getTestDocument(file); for (BsonValue test : testDocument.getArray("tests")) { data.add(new Object[]{file.getName(), test.asDocument().getString("description").getValue(), @@ -69,7 +76,7 @@ public static Collection data() throws URISyntaxException, IOException private void testInvalidUris() { Throwable expectedError = null; try { - new ConnectionString(input).getCredential(); + getMongoCredential(); } catch (Throwable t) { expectedError = t; } @@ -78,7 +85,7 @@ private void testInvalidUris() { } private void testValidUris() { - MongoCredential credential = new ConnectionString(input).getCredential(); + MongoCredential credential = getMongoCredential(); if (credential != null) { assertString("credential.source", credential.getSource()); @@ -99,6 +106,32 @@ private void testValidUris() { } } + @Nullable + private MongoCredential getMongoCredential() { + ConnectionString connectionString; + connectionString = new ConnectionString(input); + MongoCredential credential = connectionString.getCredential(); + if (credential != null) { + BsonArray callbacks = (BsonArray) getExpectedValue("callback"); + if (callbacks != null) { + for (BsonValue v : callbacks) { + String string = ((BsonString) v).getValue(); + if ("oidcRequest".equals(string)) { + credential = credential.withMechanismProperty( + OIDC_CALLBACK_KEY, + (MongoCredential.OidcCallback) (context) -> null); + } else { + fail("Unsupported callback: " + string); + } + } + } + if (MONGODB_OIDC.getMechanismName().equals(credential.getMechanism())) { + OidcAuthenticator.OidcValidator.validateBeforeUse(credential); + } + } + return credential; + } + private void assertString(final String key, final String actual) { BsonValue expected = getExpectedValue(key); @@ -142,6 +175,10 @@ private void assertMechanismProperties(final MongoCredential credential) { } } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); + if (OIDC_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcCallback); + return; + } assertNotNull(actualMechanismProperty); assertEquals(expectedValue, actualMechanismProperty); } else { diff --git a/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy b/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy index e8731439a84..d56aa8a9c7c 100644 --- a/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy @@ -601,7 +601,7 @@ class ConnectionStringSpecification extends Specification { new ConnectionString('mongodb://jeff@localhost/?' + 'authMechanism=GSSAPI' + '&authMechanismProperties=' + - 'SERVICE_NAME:foo:bar') + 'SERVICE_NAMEbar') // missing = then: thrown(IllegalArgumentException) diff --git a/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java b/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java index d2e41ebeafd..6a8d9ff4fc3 100644 --- a/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java +++ b/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java @@ -15,11 +15,16 @@ */ package com.mongodb; +import com.mongodb.assertions.Assertions; import com.mongodb.connection.ServerMonitoringMode; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; + import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -34,6 +39,48 @@ void defaults() { assertAll(() -> assertNull(connectionStringDefault.getServerMonitoringMode())); } + @Test + public void mustDecodeOidcIndividually() { + String string = "abc,d!@#$%^&*;ef:ghi"; + // encoded tags will fail parsing with an "invalid read preference tag" + // error if decoding is skipped. + String encodedTags = encode("dc:ny,rack:1"); + ConnectionString cs = new ConnectionString( + "mongodb://localhost/?readPreference=primaryPreferred&readPreferenceTags=" + encodedTags + + "&authMechanism=MONGODB-OIDC&authMechanismProperties=" + + "ENVIRONMENT:azure,TOKEN_RESOURCE:" + encode(string)); + MongoCredential credential = Assertions.assertNotNull(cs.getCredential()); + assertEquals(string, credential.getMechanismProperty("TOKEN_RESOURCE", null)); + } + + @Test + public void mustDecodeNonOidcAsWhole() { + // this string allows us to check if there is no double decoding + String rawValue = encode("ot her"); + assertAll(() -> { + // even though only one part has been encoded by the user, the whole option value (pre-split) must be decoded + ConnectionString cs = new ConnectionString( + "mongodb://foo:bar@example.com/?authMechanism=GSSAPI&authMechanismProperties=" + + "SERVICE_NAME:" + encode(rawValue) + ",CANONICALIZE_HOST_NAME:true&authSource=$external"); + MongoCredential credential = Assertions.assertNotNull(cs.getCredential()); + assertEquals(rawValue, credential.getMechanismProperty("SERVICE_NAME", null)); + }, () -> { + ConnectionString cs = new ConnectionString( + "mongodb://foo:bar@example.com/?authMechanism=GSSAPI&authMechanismProperties=" + + encode("SERVICE_NAME:" + rawValue + ",CANONICALIZE_HOST_NAME:true&authSource=$external")); + MongoCredential credential = Assertions.assertNotNull(cs.getCredential()); + assertEquals(rawValue, credential.getMechanismProperty("SERVICE_NAME", null)); + }); + } + + private static String encode(final String string) { + try { + return URLEncoder.encode(string, StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + @ParameterizedTest @ValueSource(strings = {DEFAULT_OPTIONS + "serverMonitoringMode=stream"}) void equalAndHashCode(final String connectionString) { diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java new file mode 100644 index 00000000000..276dc9b68a9 --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.reactivestreams.client.MongoClients; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; +import org.junit.jupiter.api.Test; +import reactivestreams.helpers.SubscriberHelpers; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static util.ThreadTestHelpers.executeAll; + +public class OidcAuthenticationAsyncProseTests extends OidcAuthenticationProseTests { + + @Override + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(MongoClients.create(settings)); + } + + @Test + public void testNonblockingCallbacks() { + // not a prose spec test + delayNextFind(); + + int simulatedDelayMs = 100; + TestCallback requestCallback = createCallback().setDelayMs(simulatedDelayMs); + + MongoClientSettings clientSettings = createSettings(getOidcUri(), requestCallback); + + try (com.mongodb.reactivestreams.client.MongoClient client = MongoClients.create(clientSettings)) { + executeAll(2, () -> { + SubscriberHelpers.OperationSubscriber subscriber = new SubscriberHelpers.OperationSubscriber<>(); + long t1 = System.nanoTime(); + client.getDatabase("test") + .getCollection("test") + .find() + .first() + .subscribe(subscriber); + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1); + + assertTrue(elapsedMs < simulatedDelayMs); + subscriber.get(); + }); + + // ensure both callbacks have been tested + assertEquals(1, requestCallback.getInvocations()); + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 4845ac460a1..76e49d68cdb 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -19,18 +19,14 @@ import com.mongodb.ClientEncryptionSettings; import com.mongodb.ClientSessionOptions; import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; import com.mongodb.ReadConcern; import com.mongodb.ReadConcernLevel; import com.mongodb.ReadPreference; import com.mongodb.ServerApi; import com.mongodb.ServerApiVersion; -import com.mongodb.event.TestServerMonitorListener; -import com.mongodb.internal.connection.ServerMonitoringModeUtil; -import com.mongodb.internal.connection.TestClusterListener; -import com.mongodb.logging.TestLoggingInterceptor; import com.mongodb.TransactionOptions; import com.mongodb.WriteConcern; -import com.mongodb.assertions.Assertions; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -59,11 +55,15 @@ import com.mongodb.event.ConnectionPoolListener; import com.mongodb.event.ConnectionPoolReadyEvent; import com.mongodb.event.ConnectionReadyEvent; +import com.mongodb.event.TestServerMonitorListener; +import com.mongodb.internal.connection.ServerMonitoringModeUtil; +import com.mongodb.internal.connection.TestClusterListener; import com.mongodb.internal.connection.TestCommandListener; import com.mongodb.internal.connection.TestConnectionPoolListener; import com.mongodb.internal.connection.TestServerListener; import com.mongodb.internal.logging.LogMessage; import com.mongodb.lang.NonNull; +import com.mongodb.logging.TestLoggingInterceptor; import org.bson.BsonArray; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -87,9 +87,12 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.ClusterFixture.getMultiMongosConnectionString; import static com.mongodb.ClusterFixture.isLoadBalanced; import static com.mongodb.ClusterFixture.isSharded; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder; import static com.mongodb.client.Fixture.getMultiMongosMongoClientSettingsBuilder; import static com.mongodb.client.unified.EventMatcher.getReasonString; @@ -98,6 +101,7 @@ import static com.mongodb.client.unified.UnifiedCrudHelper.asReadPreference; import static com.mongodb.client.unified.UnifiedCrudHelper.asWriteConcern; import static com.mongodb.internal.connection.AbstractConnectionPoolTest.waitForPoolAsyncWorkManagerStart; +import static java.lang.System.getenv; import static java.util.Arrays.asList; import static java.util.Collections.synchronizedList; import static org.junit.Assume.assumeTrue; @@ -391,8 +395,10 @@ private void initClient(final BsonDocument entity, final String id, .getArray("ignoreCommandMonitoringEvents", new BsonArray()).stream() .map(type -> type.asString().getValue()).collect(Collectors.toList()); ignoreCommandMonitoringEvents.add("configureFailPoint"); - TestCommandListener testCommandListener = new TestCommandListener(observeEvents, - ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue()); + TestCommandListener testCommandListener = new TestCommandListener( + observeEvents, + ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue(), + null); clientSettingsBuilder.addCommandListener(testCommandListener); putEntity(id + "-command-listener", testCommandListener, clientCommandListeners); @@ -516,6 +522,43 @@ private void initClient(final BsonDocument entity, final String id, clientSettingsBuilder.applyToServerSettings(builder -> builder.serverMonitoringMode( ServerMonitoringModeUtil.fromString(value.asString().getValue()))); break; + case "authMechanism": + if (value.equals(new BsonString(MONGODB_OIDC.getMechanismName()))) { + // authMechanismProperties depends on authMechanism + BsonDocument authMechanismProperties = entity + .getDocument("uriOptions") + .getDocument("authMechanismProperties"); + boolean hasPlaceholder = authMechanismProperties.equals( + new BsonDocument("$$placeholder", new BsonInt32(1))); + if (!hasPlaceholder) { + throw new UnsupportedOperationException( + "Unsupported authMechanismProperties for authMechanism: " + value); + } + + String env = assertNotNull(getenv("OIDC_ENV")); + MongoCredential oidcCredential = MongoCredential + .createOidcCredential(null) + .withMechanismProperty("ENVIRONMENT", env); + if (env.equals("azure")) { + oidcCredential = oidcCredential.withMechanismProperty( + MongoCredential.TOKEN_RESOURCE_KEY, getenv("AZUREOIDC_RESOURCE")); + } else if (env.equals("gcp")) { + oidcCredential = oidcCredential.withMechanismProperty( + MongoCredential.TOKEN_RESOURCE_KEY, getenv("GCPOIDC_RESOURCE")); + } + clientSettingsBuilder.credential(oidcCredential); + break; + } + throw new UnsupportedOperationException("Unsupported authMechanism: " + value); + case "authMechanismProperties": + // authMechanismProperties are handled as part of authMechanism, above + BsonValue authMechanism = entity + .getDocument("uriOptions") + .get("authMechanism"); + if (authMechanism.equals(new BsonString(MONGODB_OIDC.getMechanismName()))) { + break; + } + throw new UnsupportedOperationException("Failure to apply authMechanismProperties: " + value); default: throw new UnsupportedOperationException("Unsupported uri option: " + key); } @@ -679,7 +722,7 @@ private void initClientEncryption(final BsonDocument entity, final String id, } } - putEntity(id, clientEncryptionSupplier.apply(Assertions.notNull("mongoClient", mongoClient), builder.build()), clientEncryptions); + putEntity(id, clientEncryptionSupplier.apply(notNull("mongoClient", mongoClient), builder.build()), clientEncryptions); } private TransactionOptions getTransactionOptions(final BsonDocument options) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java b/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java index e232a4c9688..7c0d340a9ad 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java @@ -20,6 +20,7 @@ import com.mongodb.MongoClientException; import com.mongodb.MongoCommandException; import com.mongodb.MongoException; +import com.mongodb.MongoSecurityException; import com.mongodb.MongoExecutionTimeoutException; import com.mongodb.MongoServerException; import com.mongodb.MongoSocketException; @@ -76,12 +77,17 @@ void assertErrorsMatch(final BsonDocument expectedError, final Exception e) { valueMatcher.assertValuesMatch(expectedError.getDocument("errorResponse"), ((MongoCommandException) e).getResponse()); } if (expectedError.containsKey("errorCode")) { - assertTrue(context.getMessage("Exception must be of type MongoCommandException or MongoQueryException when checking" - + " for error codes"), - e instanceof MongoCommandException || e instanceof MongoWriteException); - int errorCode = (e instanceof MongoCommandException) - ? ((MongoCommandException) e).getErrorCode() - : ((MongoWriteException) e).getCode(); + Exception errorCodeException = e; + if (e instanceof MongoSecurityException && e.getCause() instanceof MongoCommandException) { + errorCodeException = (Exception) e.getCause(); + } + assertTrue(context.getMessage("Exception must be of type MongoCommandException or MongoWriteException when checking" + + " for error codes, but was " + e.getClass().getSimpleName()), + errorCodeException instanceof MongoCommandException + || errorCodeException instanceof MongoWriteException); + int errorCode = (errorCodeException instanceof MongoCommandException) + ? ((MongoCommandException) errorCodeException).getErrorCode() + : ((MongoWriteException) errorCodeException).getCode(); assertEquals(context.getMessage("Error codes must match"), expectedError.getNumber("errorCode").intValue(), errorCode); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java b/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java index bf6c0dcda01..60553c73f96 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java @@ -69,7 +69,19 @@ public static boolean runOnRequirementsMet(final BsonArray runOnRequirements, fi } break; case "auth": - if (curRequirement.getValue().asBoolean().getValue() == (clientSettings.getCredential() == null)) { + boolean authRequired = curRequirement.getValue().asBoolean().getValue(); + boolean credentialPresent = clientSettings.getCredential() != null; + + if (authRequired != credentialPresent) { + requirementMet = false; + break requirementLoop; + } + break; + case "authMechanism": + boolean containsMechanism = getServerParameters() + .getArray("authenticationMechanisms") + .contains(curRequirement.getValue()); + if (!containsMechanism) { requirementMet = false; break requirementLoop; } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java new file mode 100644 index 00000000000..f94977f2546 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.unified; + +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; + +public class UnifiedAuthTest extends UnifiedSyncTest { + public UnifiedAuthTest(@SuppressWarnings("unused") final String fileDescription, + @SuppressWarnings("unused") final String testDescription, + final String schemaVersion, final BsonArray runOnRequirements, final BsonArray entitiesArray, + final BsonArray initialData, final BsonDocument definition) { + super(schemaVersion, runOnRequirements, entitiesArray, initialData, definition); + } + + @Parameterized.Parameters(name = "{0}: {1}") + public static Collection data() throws URISyntaxException, IOException { + return getTestData("unified-test-format/auth"); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java index c1741bd5f33..62eac081d4e 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java @@ -210,7 +210,9 @@ public void setUp() { || schemaVersion.equals("1.14") || schemaVersion.equals("1.15") || schemaVersion.equals("1.16") - || schemaVersion.equals("1.17")); + || schemaVersion.equals("1.17") + || schemaVersion.equals("1.18") + || schemaVersion.equals("1.19")); if (runOnRequirements != null) { assumeTrue("Run-on requirements not met", runOnRequirementsMet(runOnRequirements, getMongoClientSettings(), getServerVersion())); diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java new file mode 100644 index 00000000000..9915f6a6a34 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -0,0 +1,1120 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoSecurityException; +import com.mongodb.MongoSocketException; +import com.mongodb.assertions.Assertions; +import com.mongodb.client.Fixture; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.TestListener; +import com.mongodb.event.CommandListener; +import com.mongodb.lang.Nullable; +import org.bson.BsonArray; +import org.bson.BsonBoolean; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.bson.Document; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.opentest4j.AssertionFailedError; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.ENVIRONMENT_KEY; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallback; +import static com.mongodb.MongoCredential.OidcCallbackContext; +import static com.mongodb.MongoCredential.OidcCallbackResult; +import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static java.lang.System.getenv; +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static util.ThreadTestHelpers.executeAll; + +/** + * See + * Prose Tests. + */ +public class OidcAuthenticationProseTests { + + private String appName; + + public static boolean oidcTestsEnabled() { + return Boolean.parseBoolean(getenv().get("OIDC_TESTS_ENABLED")); + } + + private void assumeTestEnvironment() { + assumeTrue(getenv("OIDC_TOKEN_DIR") != null); + } + + protected static String getOidcUri() { + return getenv("MONGODB_URI_SINGLE"); + } + + private static String getOidcUriMulti() { + return getenv("MONGODB_URI_MULTI"); + } + + private static String getOidcEnv() { + return getenv("OIDC_ENV"); + } + + private static void assumeAzure() { + assumeTrue(getOidcEnv().equals("azure")); + } + + @Nullable + private static String getUserWithDomain(@Nullable final String user) { + return user == null ? null : user + "@" + getenv("OIDC_DOMAIN"); + } + + private static String oidcTokenDirectory() { + String dir = getenv("OIDC_TOKEN_DIR"); + if (!dir.endsWith("/")) { + dir = dir + "/"; + } + return dir; + } + + private static String getTestTokenFilePath() { + return getenv(OidcAuthenticator.OIDC_TOKEN_FILE); + } + + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return MongoClients.create(settings); + } + + @BeforeEach + public void beforeEach() { + assumeTrue(oidcTestsEnabled()); + InternalStreamConnection.setRecordEverything(true); + this.appName = this.getClass().getSimpleName() + "-" + new Random().nextInt(Integer.MAX_VALUE); + } + + @AfterEach + public void afterEach() { + InternalStreamConnection.setRecordEverything(false); + } + + @Test + public void test1p1CallbackIsCalledDuringAuth() { + // #. Create a ``MongoClient`` configured with an OIDC callback... + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(callback); + // #. Perform a find operation that succeeds + performFind(clientSettings); + assertEquals(1, callback.invocations.get()); + } + + @Test + public void test1p2CallbackCalledOnceForMultipleConnections() { + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + List threads = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + Thread t = new Thread(() -> performFind(mongoClient)); + t.setDaemon(true); + t.start(); + threads.add(t); + } + for (Thread t : threads) { + try { + t.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + assertEquals(1, callback.invocations.get()); + } + + @Test + public void test2p1ValidCallbackInputs() { + Duration expectedSeconds = Duration.ofMinutes(5); + + TestCallback callback1 = createCallback(); + // #. Verify that the request callback was called with the appropriate + // inputs, including the timeout parameter if possible. + OidcCallback callback2 = (context) -> { + assertEquals(expectedSeconds, context.getTimeout()); + return callback1.onRequest(context); + }; + MongoClientSettings clientSettings = createSettings(callback2); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // callback was called + assertEquals(1, callback1.getInvocations()); + } + } + + @Test + public void test2p2RequestCallbackReturnsNull() { + //noinspection ConstantConditions + OidcCallback callback = (context) -> null; + MongoClientSettings clientSettings = this.createSettings(callback); + assertFindFails(clientSettings, MongoConfigurationException.class, + "Result of callback must not be null"); + } + + @Test + public void test2p3CallbackReturnsMissingData() { + // #. Create a client with a request callback that returns data not + // conforming to the OIDCRequestTokenResult with missing field(s). + OidcCallback callback = (context) -> { + //noinspection ConstantConditions + return new OidcCallbackResult(null); + }; + // we ensure that the error is propagated + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + assertCause(IllegalArgumentException.class, + "accessToken can not be null", + () -> performFind(mongoClient)); + } + } + + @Test + public void test2p4InvalidClientConfigurationWithCallback() { + String uri = getOidcUri() + "&authMechanismProperties=ENVIRONMENT:" + getOidcEnv(); + MongoClientSettings settings = createSettings( + uri, createCallback(), null, OIDC_CALLBACK_KEY); + assertCause(IllegalArgumentException.class, + "OIDC_CALLBACK must not be specified when ENVIRONMENT is specified", + () -> performFind(settings)); + } + + @Test + public void test3p1AuthFailsWithCachedToken() throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException { + TestCallback callbackWrapped = createCallback(); + // reference to the token to poison + CompletableFuture poisonToken = new CompletableFuture<>(); + OidcCallback callback = (context) -> { + OidcCallbackResult result = callbackWrapped.onRequest(context); + String accessToken = result.getAccessToken(); + if (!poisonToken.isDone()) { + poisonToken.complete(accessToken); + } + return result; + }; + + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // populate cache + performFind(mongoClient); + assertEquals(1, callbackWrapped.invocations.get()); + // Poison the *Client Cache* with an invalid access token. + // uses reflection + String poisonString = poisonToken.get(); + Field f = String.class.getDeclaredField("value"); + f.setAccessible(true); + byte[] poisonChars = (byte[]) f.get(poisonString); + poisonChars[0] = '~'; + poisonChars[1] = '~'; + + assertEquals(1, callbackWrapped.invocations.get()); + + // cause another connection to be opened + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + } + assertEquals(2, callbackWrapped.invocations.get()); + } + + @Test + public void test3p2AuthFailsWithoutCachedToken() { + OidcCallback callback = + (x) -> new OidcCallbackResult("invalid_token"); + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + assertCause(MongoCommandException.class, + "Command failed with error 18 (AuthenticationFailed):", + () -> performFind(mongoClient)); + } + } + + @Test + public void test3p3UnexpectedErrorDoesNotClearCache() { + assumeTestEnvironment(); + + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(getOidcUri(), callback, commandListener); + + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + failCommand(20, 1, "saslStart"); + assertCause(MongoCommandException.class, + "Command failed with error 20", + () -> performFind(mongoClient)); + + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "saslStart started", + "saslStart failed" + ), listener.getEventStrings()); + + assertEquals(1, callback.getInvocations()); + performFind(mongoClient); + assertEquals(1, callback.getInvocations()); + } + } + + @Test + public void test4p1Reauthentication() { + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + failCommand(391, 1, "find"); + // #. Perform a find operation that succeeds. + performFind(mongoClient); + } + assertEquals(2, callback.invocations.get()); + } + + @Test + public void test4p2ReadCommandsFailIfReauthenticationFails() { + // Create a `MongoClient` whose OIDC callback returns one good token + // and then bad tokens after the first call. + TestCallback wrappedCallback = createCallback(); + OidcCallback callback = (context) -> { + OidcCallbackResult result1 = wrappedCallback.callback(context); + return new OidcCallbackResult(wrappedCallback.getInvocations() > 1 ? "bad" : result1.getAccessToken()); + }; + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + failCommand(391, 1, "find"); + assertCause(MongoCommandException.class, + "Command failed with error 18", + () -> performFind(mongoClient)); + } + assertEquals(2, wrappedCallback.invocations.get()); + } + + @Test + public void test4p3WriteCommandsFailIfReauthenticationFails() { + // Create a `MongoClient` whose OIDC callback returns one good token + // and then bad tokens after the first call. + TestCallback wrappedCallback = createCallback(); + OidcCallback callback = (context) -> { + OidcCallbackResult result1 = wrappedCallback.callback(context); + return new OidcCallbackResult( + wrappedCallback.getInvocations() > 1 ? "bad" : result1.getAccessToken()); + }; + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performInsert(mongoClient); + failCommand(391, 1, "insert"); + assertCause(MongoCommandException.class, + "Command failed with error 18", + () -> performInsert(mongoClient)); + } + assertEquals(2, wrappedCallback.invocations.get()); + } + + private static void performInsert(final MongoClient mongoClient) { + mongoClient + .getDatabase("test") + .getCollection("test") + .insertOne(Document.parse("{ x: 1 }")); + } + + @Test + public void test5p1AzureSucceedsWithNoUsername() { + assumeAzure(); + String oidcUri = getOidcUri(); + MongoClientSettings clientSettings = createSettings(oidcUri, createCallback(), null); + // Create an OIDC configured client with `ENVIRONMENT:azure` and a valid + // `TOKEN_RESOURCE` and no username. + MongoCredential credential = Assertions.assertNotNull(clientSettings.getCredential()); + assertNotNull(credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null)); + assertNull(credential.getUserName()); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // Perform a `find` operation that succeeds. + performFind(mongoClient); + } + } + + @Test + public void test5p2AzureFailsWithBadUsername() { + assumeAzure(); + String oidcUri = getOidcUri(); + ConnectionString cs = new ConnectionString(oidcUri); + MongoCredential oldCredential = Assertions.assertNotNull(cs.getCredential()); + String tokenResource = oldCredential.getMechanismProperty(TOKEN_RESOURCE_KEY, null); + assertNotNull(tokenResource); + MongoCredential cred = MongoCredential.createOidcCredential("bad") + .withMechanismProperty(ENVIRONMENT_KEY, "azure") + .withMechanismProperty(TOKEN_RESOURCE_KEY, tokenResource); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .retryReads(false) + .applyConnectionString(cs) + .credential(cred); + MongoClientSettings clientSettings = builder.build(); + // the failure is external to the driver + assertFindFails(clientSettings, IOException.class, "400 Bad Request"); + } + + // Tests for human authentication ("testh", to preserve ordering) + + @Test + public void testh1p1SinglePrincipalImplicitUsername() { + assumeTestEnvironment(); + // #. Create default OIDC client with authMechanism=MONGODB-OIDC. + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + assertEquals(1, callback.invocations.get()); + } + + @Test + public void testh1p2SinglePrincipalExplicitUsername() { + assumeTestEnvironment(); + // #. Create a client with MONGODB_URI_SINGLE, a username of test_user1, + // authMechanism=MONGODB-OIDC, and the OIDC human callback. + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createSettingsHuman(getUserWithDomain("test_user1"), callback, getOidcUri()); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p3MultiplePrincipalUser1() { + assumeTestEnvironment(); + // #. Create a client with MONGODB_URI_MULTI, a username of test_user1, + // authMechanism=MONGODB-OIDC, and the OIDC human callback. + MongoClientSettings clientSettings = createSettingsMulti(getUserWithDomain("test_user1"), createHumanCallback()); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p4MultiplePrincipalUser2() { + assumeTestEnvironment(); + //- Create a human callback that reads in the generated ``test_user2`` token file. + //- Create a client with ``MONGODB_URI_MULTI``, a username of ``test_user2``, + // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. + MongoClientSettings clientSettings = createSettingsMulti(getUserWithDomain("test_user2"), createHumanCallback() + .setPathSupplier(() -> tokenQueue("test_user2").remove())); + performFind(clientSettings); + } + + @Test + public void testh1p5MultiplePrincipalNoUser() { + assumeTestEnvironment(); + // Create an OIDC configured client with `MONGODB_URI_MULTI` and no username. + MongoClientSettings clientSettings = createSettingsMulti(null, createHumanCallback()); + // Assert that a `find` operation fails. + assertFindFails(clientSettings, MongoCommandException.class, "Authentication failed"); + } + + @Test + public void testh1p6AllowedHostsBlocked() { + assumeTestEnvironment(); + //- Create a default OIDC client, with an ``ALLOWED_HOSTS`` that is an empty list. + //- Assert that a ``find`` operation fails with a client-side error. + MongoClientSettings clientSettings1 = createSettings(getOidcUri(), + createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Collections.emptyList()); + assertFindFails(clientSettings1, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + + //- Create a client that uses the URL + // ``mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com``, a + // human callback, and an ``ALLOWED_HOSTS`` that contains ``["example.com"]``. + //- Assert that a ``find`` operation fails with a client-side error. + MongoClientSettings clientSettings2 = createSettings(getOidcUri() + "&ignored=example.com", + createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Arrays.asList("example.com")); + assertFindFails(clientSettings2, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + } + + // Not a prose test + @Test + public void testAllowedHostsDisallowedInConnectionString() { + String string = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:localhost"; + assertCause(IllegalArgumentException.class, + "connection string contains disallowed mechanism properties", + () -> new ConnectionString(string)); + } + + @Test + public void testh1p7AllowedHostsInConnectionStringIgnored() { + // example.com changed to localhost, because resolveAdditionalQueryParametersFromTxtRecords + // fails with "Failed looking up TXT record for host example.com" + String string = "mongodb+srv://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22localhost%22%5D"; + assertCause(IllegalArgumentException.class, + "connection string contains disallowed mechanism properties", + () -> new ConnectionString(string)); + } + + @Test + public void testh1p8MachineIdpWithHumanCallback() { + assumeTrue(getenv("OIDC_IS_LOCAL") != null); + + TestCallback callback = createHumanCallback() + .setPathSupplier(() -> oidcTokenDirectory() + "test_machine"); + MongoClientSettings clientSettings = createSettingsHuman( + "test_machine", callback, getOidcUri()); + performFind(clientSettings); + } + + @Test + public void testh2p1ValidCallbackInputs() { + assumeTestEnvironment(); + TestCallback callback1 = createHumanCallback(); + OidcCallback callback2 = (context) -> { + MongoCredential.IdpInfo idpInfo = assertNotNull(context.getIdpInfo()); + assertTrue(assertNotNull(idpInfo.getClientId()).startsWith("0oad")); + assertTrue(idpInfo.getIssuer().endsWith("mock-identity-config-oidc")); + assertEquals(Arrays.asList("fizz", "buzz"), idpInfo.getRequestScopes()); + assertEquals(Duration.ofMinutes(5), context.getTimeout()); + return callback1.onRequest(context); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // Ensure that callback was called + assertEquals(1, callback1.getInvocations()); + } + } + + @Test + public void testh2p2HumanCallbackReturnsMissingData() { + assumeTestEnvironment(); + //noinspection ConstantConditions + OidcCallback callbackNull = (context) -> null; + assertFindFails(createHumanSettings(callbackNull, null), + MongoConfigurationException.class, + "Result of callback must not be null"); + + //noinspection ConstantConditions + OidcCallback callback = + (context) -> new OidcCallbackResult(null); + assertFindFails(createHumanSettings(callback, null), + IllegalArgumentException.class, + "accessToken can not be null"); + } + + // not a prose test + @Test + public void testRefreshTokenAbsent() { + // additionally, check validation for refresh in machine workflow: + OidcCallback callbackMachineRefresh = + (context) -> new OidcCallbackResult("access", Duration.ZERO, "exists"); + assertFindFails(createSettings(callbackMachineRefresh), + MongoConfigurationException.class, + "Refresh token must only be provided in human workflow"); + } + + @Test + public void testh2p3RefreshTokenPassed() { + assumeTestEnvironment(); + AtomicInteger refreshTokensProvided = new AtomicInteger(); + TestCallback callback1 = createHumanCallback(); + OidcCallback callback2 = (context) -> { + if (context.getRefreshToken() != null) { + refreshTokensProvided.incrementAndGet(); + } + return callback1.onRequest(context); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(2, callback1.getInvocations()); + assertEquals(1, refreshTokensProvided.get()); + } + } + + @Test + public void testh3p1UsesSpecAuthIfCachedToken() { + assumeTestEnvironment(); + MongoClientSettings clientSettings = createHumanSettings(createHumanCallback(), null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + failCommandAndCloseConnection("find", 1); + assertCause(MongoSocketException.class, + "Prematurely reached end of stream", + () -> performFind(mongoClient)); + failCommand(18, 1, "saslStart"); + performFind(mongoClient); + } + } + + @Test + public void testh3p2NoSpecAuthIfNoCachedToken() { + assumeTestEnvironment(); + failCommand(18, 1, "saslStart"); + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + assertFindFails(createHumanSettings(createHumanCallback(), commandListener), + MongoCommandException.class, + "Command failed with error 18"); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "saslStart started", + "saslStart failed" + ), listener.getEventStrings()); + listener.clear(); + } + + @Test + public void testh4p1ReauthenticationSucceeds() { + assumeTestEnvironment(); + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + TestCallback callback = createHumanCallback() + .setEventListener(listener); + MongoClientSettings clientSettings = createHumanSettings(callback, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + listener.clear(); + assertEquals(1, callback.getInvocations()); + failCommand(391, 1, "find"); + // Perform another find operation that succeeds. + performFind(mongoClient); + assertEquals(Arrays.asList( + // first find fails: + "find started", + "find failed", + "onRequest invoked (Refresh Token: present - IdpInfo: present)", + "read access token: test_user1", + "saslStart started", + "saslStart succeeded", + // second find succeeds: + "find started", + "find succeeded" + ), listener.getEventStrings()); + assertEquals(2, callback.getInvocations()); + } + } + + @Test + public void testh4p2SucceedsNoRefresh() { + assumeTestEnvironment(); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(callback, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(1, callback.getInvocations()); + + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(2, callback.getInvocations()); + } + } + + + @Test + public void testh4p3SucceedsAfterRefreshFails() { + assumeTestEnvironment(); + TestCallback callback1 = createHumanCallback(); + OidcCallback callback2 = (context) -> { + OidcCallbackResult oidcCallbackResult = callback1.onRequest(context); + return new OidcCallbackResult(oidcCallbackResult.getAccessToken(), Duration.ofMinutes(5), "BAD_REFRESH"); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(2, callback1.getInvocations()); + } + } + + @Test + public void testh4p4Fails() { + assumeTestEnvironment(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires"); + TestCallback callback1 = createHumanCallback() + .setPathSupplier(() -> tokens.remove()); + OidcCallback callback2 = (context) -> { + OidcCallbackResult oidcCallbackResult = callback1.onRequest(context); + return new OidcCallbackResult(oidcCallbackResult.getAccessToken(), Duration.ofMinutes(5), "BAD_REFRESH"); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(1, callback1.getInvocations()); + failCommand(391, 1, "find"); + assertCause(MongoCommandException.class, + "Command failed with error 18", + () -> performFind(mongoClient)); + assertEquals(3, callback1.getInvocations()); + } + } + + // Not a prose test + @Test + public void testErrorClearsCache() { + assumeTestEnvironment(); + // #. Create a new client with a valid request callback that + // gives credentials that expire within 5 minutes and + // a refresh callback that gives invalid credentials. + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback callback = createHumanCallback() + .setRefreshToken("refresh") + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createHumanSettings(callback, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Ensure that a find operation adds a new entry to the cache. + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + // no speculative auth. Send principal request: + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1", + // the refresh token from the callback is cached here + // send jwt: + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that a subsequent find operation results in a 391 error. + failCommand(391, 1, "find"); + // ensure that the operation entirely fails, after attempting both potential fallback callbacks + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", // reauth 391; current access token is invalid + // fall back to refresh token, from prior find + "onRequest invoked (Refresh Token: present - IdpInfo: present)", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", // it is expired, fails immediately + // fall back to principal request, and non-refresh callback: + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1_expires", + "saslContinue started", + "saslContinue failed" // also fails due to 391 + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that the cache value cleared. + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + "find started", + "find failed", + // falling back to principal request, onRequest callback. + // this implies that the cache has been cleared during the + // preceding find operation. + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1_1", + "saslContinue started", + "saslContinue succeeded", + // auth has finished + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + } + } + + + private MongoClientSettings createSettings(final OidcCallback callback) { + return createSettings(getOidcUri(), callback, null); + } + + public MongoClientSettings createSettings( + final String connectionString, + @Nullable final TestCallback callback) { + return createSettings(connectionString, callback, null); + } + + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcCallback callback, + @Nullable final CommandListener commandListener) { + String cleanedConnectionString = callback == null ? connectionString : connectionString + .replace("ENVIRONMENT:azure,", "") + .replace("ENVIRONMENT:gcp,", "") + .replace("ENVIRONMENT:test,", ""); + return createSettings(cleanedConnectionString, callback, commandListener, OIDC_CALLBACK_KEY); + } + + private MongoClientSettings createHumanSettings( + final OidcCallback callback, @Nullable final TestCommandListener commandListener) { + return createHumanSettings(getOidcUri(), callback, commandListener); + } + + private MongoClientSettings createHumanSettings( + final String connectionString, + @Nullable final OidcCallback callback, + @Nullable final CommandListener commandListener) { + return createSettings(connectionString, callback, commandListener, OIDC_HUMAN_CALLBACK_KEY); + } + + private MongoClientSettings createSettings( + final String connectionString, + final @Nullable OidcCallback callback, + @Nullable final CommandListener commandListener, + final String oidcCallbackKey) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = assertNotNull(cs.getCredential()); + if (callback != null) { + credential = credential.withMechanismProperty(oidcCallbackKey, callback); + } + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .retryReads(false) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcCallback callback, + @Nullable final CommandListener commandListener, + final String oidcCallbackKey, + @Nullable final List allowedHosts) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = cs.getCredential() + .withMechanismProperty(oidcCallbackKey, callback) + .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + + private MongoClientSettings createSettingsMulti(@Nullable final String user, final OidcCallback callback) { + return createSettingsHuman(user, callback, getOidcUriMulti()); + } + + private MongoClientSettings createSettingsHuman(@Nullable final String user, final OidcCallback callback, final String oidcUri) { + ConnectionString cs = new ConnectionString(oidcUri); + MongoCredential credential = MongoCredential.createOidcCredential(user) + .withMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, callback); + return MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .retryReads(false) + .credential(credential) + .build(); + } + + private void performFind(final MongoClientSettings settings) { + try (MongoClient mongoClient = createMongoClient(settings)) { + performFind(mongoClient); + } + } + + private void assertFindFails( + final MongoClientSettings settings, + final Class expectedExceptionOrCause, + final String expectedMessage) { + try (MongoClient mongoClient = createMongoClient(settings)) { + assertCause(expectedExceptionOrCause, expectedMessage, () -> performFind(mongoClient)); + } + } + + private void performFind(final MongoClient mongoClient) { + mongoClient + .getDatabase("test") + .getCollection("test") + .find() + .first(); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Executable e) { + Throwable cause = assertThrows(Throwable.class, e); + while (cause.getCause() != null) { + cause = cause.getCause(); + } + if (!cause.getMessage().contains(expectedMessageFragment)) { + throw new AssertionFailedError("Unexpected message: " + cause.getMessage(), cause); + } + if (!expectedCause.isInstance(cause)) { + throw new AssertionFailedError("Unexpected cause: " + cause.getClass(), assertThrows(Throwable.class, e)); + } + } + + protected void delayNextFind() { + + try (MongoClient client = createMongoClient(Fixture.getMongoClientSettings())) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(1))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("failCommands", new BsonArray(asList(new BsonString("find")))) + .append("blockConnection", new BsonBoolean(true)) + .append("blockTimeMS", new BsonInt32(100))); + client.getDatabase("admin").runCommand(failPointDocument); + } + } + + protected void failCommand(final int code, final int times, final String... commands) { + try (MongoClient mongoClient = createMongoClient(Fixture.getMongoClientSettings())) { + List list = Arrays.stream(commands).map(c -> new BsonString(c)).collect(Collectors.toList()); + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("failCommands", new BsonArray(list)) + .append("errorCode", new BsonInt32(code))); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + } + + private void failCommandAndCloseConnection(final String command, final int times) { + try (MongoClient mongoClient = createMongoClient(Fixture.getMongoClientSettings())) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("closeConnection", new BsonBoolean(true)) + .append("failCommands", new BsonArray(Arrays.asList(new BsonString(command)))) + ); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + } + + public static class TestCallback implements OidcCallback { + private final AtomicInteger invocations = new AtomicInteger(); + @Nullable + private final Integer delayInMilliseconds; + @Nullable + private final String refreshToken; + @Nullable + private final AtomicInteger concurrentTracker; + @Nullable + private final TestListener testListener; + @Nullable + private final Supplier pathSupplier; + + public TestCallback() { + this(null, null, new AtomicInteger(), null, null); + } + + public TestCallback( + @Nullable final String refreshToken, + @Nullable final Integer delayInMilliseconds, + @Nullable final AtomicInteger concurrentTracker, + @Nullable final TestListener testListener, + @Nullable final Supplier pathSupplier) { + this.refreshToken = refreshToken; + this.delayInMilliseconds = delayInMilliseconds; + this.concurrentTracker = concurrentTracker; + this.testListener = testListener; + this.pathSupplier = pathSupplier; + } + + public int getInvocations() { + return invocations.get(); + } + + @Override + public OidcCallbackResult onRequest(final OidcCallbackContext context) { + if (testListener != null) { + testListener.add("onRequest invoked (" + + "Refresh Token: " + (context.getRefreshToken() == null ? "none" : "present") + + " - IdpInfo: " + (context.getIdpInfo() == null ? "none" : "present") + + ")"); + } + return callback(context); + } + + private OidcCallbackResult callback(final OidcCallbackContext context) { + if (concurrentTracker != null) { + if (concurrentTracker.get() > 0) { + throw new RuntimeException("Callbacks should not be invoked by multiple threads."); + } + concurrentTracker.incrementAndGet(); + } + try { + invocations.incrementAndGet(); + try { + simulateDelay(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + MongoCredential credential = assertNotNull(new ConnectionString(getOidcUri()).getCredential()); + String oidcEnv = getOidcEnv(); + OidcCallback c; + if (oidcEnv.contains("azure")) { + c = OidcAuthenticator.getAzureCallback(credential); + } else if (oidcEnv.contains("gcp")) { + c = OidcAuthenticator.getGcpCallback(credential); + } else { + c = getProseTestCallback(); + } + return c.onRequest(context); + + } finally { + if (concurrentTracker != null) { + concurrentTracker.decrementAndGet(); + } + } + } + + private OidcCallback getProseTestCallback() { + return (x) -> { + try { + Path path = Paths.get(pathSupplier == null + ? getTestTokenFilePath() + : pathSupplier.get()); + String accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + if (testListener != null) { + testListener.add("read access token: " + path.getFileName()); + } + return new OidcCallbackResult(accessToken, Duration.ZERO, refreshToken); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + } + + private void simulateDelay() throws InterruptedException { + if (delayInMilliseconds != null) { + Thread.sleep(delayInMilliseconds); + } + } + + public TestCallback setDelayMs(final int milliseconds) { + return new TestCallback( + this.refreshToken, + milliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setConcurrentTracker(final AtomicInteger c) { + return new TestCallback( + this.refreshToken, + this.delayInMilliseconds, + c, + this.testListener, + this.pathSupplier); + } + + public TestCallback setEventListener(final TestListener testListener) { + return new TestCallback( + this.refreshToken, + this.delayInMilliseconds, + this.concurrentTracker, + testListener, + this.pathSupplier); + } + + public TestCallback setPathSupplier(final Supplier pathSupplier) { + return new TestCallback( + this.refreshToken, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + pathSupplier); + } + + public TestCallback setRefreshToken(final String token) { + return new TestCallback( + token, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + } + + private ConcurrentLinkedQueue tokenQueue(final String... queue) { + String tokenPath = oidcTokenDirectory(); + return java.util.stream.Stream + .of(queue) + .map(v -> tokenPath + v) + .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); + } + + public TestCallback createCallback() { + return new TestCallback(); + } + + public TestCallback createHumanCallback() { + return new TestCallback() + .setPathSupplier(() -> oidcTokenDirectory() + "test_user1") + .setRefreshToken("refreshToken"); + } +}