diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 78b1ed4af2..a03c32f3bf 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -25,6 +25,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @@ -148,7 +150,17 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update parameters = getParameterMap(parser.map()); break; case CONNECTOR_CREDENTIAL_FIELD: - credential = parser.mapStrings(); + // We need to filter out any key string that is trying to imitate the subfield of any kind of ARN of the credential map + credential = new HashMap<>(); + Map credentialKeyToAdd = parser.mapStrings(); + Pattern pattern = Pattern.compile("[a-zA-Z]+Arn\\."); + for (String key : credentialKeyToAdd.keySet()) { + Matcher matcher = pattern.matcher(key); + boolean matchFound = matcher.find(); + if (!matchFound) { + credential.put(key, credentialKeyToAdd.get(key)); + } + } break; case CONNECTOR_ACTIONS_FIELD: actions = new ArrayList<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 125bbaf393..64aa17d91c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -31,6 +31,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -191,6 +192,28 @@ public void testParse_ArrayParameter() throws Exception { }); } + @Test + public void testParse_SecretArnPrefix() throws Exception { + String expectedInputStr = "{\"name\":\"test_connector_name\"," + + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"," + + "\"secretArn\":\"test_secretArn_value\", \"secretArn.key\":\"test_key_value\"," + + "\"roleArn\":\"test_roleArn_value\", \"roleArn.subfield\":\"test_subfield_value\",\"test_Arn_test\":\"test_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + + "\"access_mode\":\"PUBLIC\"}"; + HashSet expectedCredentialKeys = new HashSet<>(Arrays.asList("key", "secretArn", "roleArn","test_Arn_test")); + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals(expectedCredentialKeys, parsedInput.getCredential().keySet()); + assertEquals("test_secretArn_value", parsedInput.getCredential().get("secretArn")); + assertEquals("test_roleArn_value", parsedInput.getCredential().get("roleArn")); + }); + } + @Test public void testParseWithDryRun() throws Exception { String expectedInputStrWithDryRun = "{\"dry_run\":true}";