From 5a2d4555b16ea852e8c62d5cc7614226812bee98 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Fri, 15 Dec 2023 04:27:27 +0000 Subject: [PATCH] revert AOS code Signed-off-by: Sicheng Song --- .../common/connector/AbstractConnector.java | 3 - .../ml/common/connector/AwsConnector.java | 26 +--- .../ml/common/connector/Connector.java | 2 - .../MLDeployModelControllerNodeRequest.java | 5 +- .../MLUpdateModelCacheNodesRequest.java | 2 +- ml-algorithms/build.gradle | 3 - .../remote/AwsConnectorExecutor.java | 110 --------------- .../remote/HttpJsonConnectorExecutor.java | 131 ------------------ .../remote/RemoteConnectorExecutor.java | 19 +-- .../credentials/CredentialsProvider.java | 7 - .../aws/ClientConfigurationHelper.java | 59 -------- .../aws/CredentialsProviderFactory.java | 10 -- .../ExpirableCredentialsProviderFactory.java | 84 ----------- .../InternalAuthApiCredentialsProvider.java | 86 ------------ .../InternalAuthCredentialsApiRequest.java | 108 --------------- .../aws/InternalAuthCredentialsClient.java | 64 --------- .../InternalAuthCredentialsClientPool.java | 36 ----- .../aws/InternalAwsCredentials.java | 56 -------- .../aws/PrivilegedCredentialsProvider.java | 31 ----- .../engine/credentials/aws/SocketAccess.java | 42 ------ .../credentialscommunication/Credentials.java | 21 --- .../CredentialsRequest.java | 29 ---- .../SecretManagerCredentials.java | 42 ------ .../SecretsManager.java | 23 --- .../engine/credentialscommunication/Util.java | 38 ----- .../ml/engine/factory/CredentialsFactory.java | 69 --------- .../engine/factory/SecretsManagerFactory.java | 110 --------------- plugin/build.gradle | 3 - .../TransportCreateConnectorAction.java | 1 - .../TransportPredictionTaskAction.java | 3 +- .../TransportRegisterModelAction.java | 1 - .../UpdateModelCacheTransportAction.java | 2 +- .../update/UpdateModelTransportAction.java | 31 +---- .../opensearch/ml/model/MLModelManager.java | 3 +- 34 files changed, 22 insertions(+), 1238 deletions(-) delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/CredentialsProvider.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ClientConfigurationHelper.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/CredentialsProviderFactory.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ExpirableCredentialsProviderFactory.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthApiCredentialsProvider.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsApiRequest.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClient.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClientPool.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAwsCredentials.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/PrivilegedCredentialsProvider.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/SocketAccess.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Credentials.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/CredentialsRequest.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretManagerCredentials.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretsManager.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Util.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/CredentialsFactory.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/SecretsManagerFactory.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 058b1f359d..5fa213db99 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -31,8 +31,6 @@ public abstract class AbstractConnector implements Connector { public static final String ACCESS_KEY_FIELD = "access_key"; public static final String SECRET_KEY_FIELD = "secret_key"; public static final String SESSION_TOKEN_FIELD = "session_token"; - public static final String ROLE_ARN_FIELD = "roleArn"; - public static final String SECRET_ARN_FIELD = "secretArn"; public static final String NAME_FIELD = "name"; public static final String VERSION_FIELD = "version"; public static final String DESCRIPTION_FIELD = "description"; @@ -52,7 +50,6 @@ public abstract class AbstractConnector implements Connector { protected String protocol; protected Map parameters; - @Getter protected Map credential; protected Map decryptedHeaders; @Setter diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index b2cc03d8fc..ed9c64ac94 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -33,21 +33,18 @@ public AwsConnector(String name, String description, String version, String prot Map parameters, Map credential, List actions, List backendRoles, AccessMode accessMode, User owner) { super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner); - //validate(); - validateAwsConnectorInManagedService(); + validate(); } public AwsConnector(String protocol, XContentParser parser) throws IOException { super(protocol, parser); - //validate(); - validateAwsConnectorInManagedService(); + validate(); } public AwsConnector(StreamInput input) throws IOException { super(input); - //validate(); - validateAwsConnectorInManagedService(); + validate(); } private void validate() { @@ -62,19 +59,6 @@ private void validate() { } } - private void validateAwsConnectorInManagedService() { - // Users who are using AWS protocol must give a roleArn in credentials - if (credential == null || !credential.containsKey("roleArn")) { - throw new IllegalArgumentException("please supply a valid roleArn in credentials if utilizing an AWS service"); - } - if ((credential == null || !credential.containsKey(SERVICE_NAME_FIELD)) && (parameters == null || !parameters.containsKey(SERVICE_NAME_FIELD))) { - throw new IllegalArgumentException("Missing or invalid service name"); - } - if ((credential == null || !credential.containsKey(REGION_FIELD)) && (parameters == null || !parameters.containsKey(REGION_FIELD))) { - throw new IllegalArgumentException("Missing region"); - } - } - @Override public Connector cloneConnector() { try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){ @@ -90,10 +74,6 @@ public String getAccessKey() { return decryptedCredential.get(ACCESS_KEY_FIELD); } - public String getRoleArn() { - return decryptedCredential.get(ROLE_ARN_FIELD); - } - public String getSecretKey() { return decryptedCredential.get(SECRET_KEY_FIELD); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 9c18faefd8..0652a83421 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -160,6 +160,4 @@ default void validateConnectorURL(List urlRegexes) { } Map getDecryptedHeaders(); - - Map getCredential(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java index c3352054c3..d11e488641 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.controller; -import org.opensearch.action.support.nodes.BaseNodeRequest; import java.io.IOException; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -public class MLDeployModelControllerNodeRequest extends BaseNodeRequest { +import org.opensearch.transport.TransportRequest; + +public class MLDeployModelControllerNodeRequest extends TransportRequest { @Getter private MLDeployModelControllerNodesRequest deployModelControllerNodesRequest; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update/MLUpdateModelCacheNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/update/MLUpdateModelCacheNodesRequest.java index 88a1159071..8dacc06cef 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/update/MLUpdateModelCacheNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/update/MLUpdateModelCacheNodesRequest.java @@ -27,7 +27,7 @@ public MLUpdateModelCacheNodesRequest(String[] nodeIds, String modelId) { this.modelId = modelId; } - public MLInPlaceUpdateModelNodesRequest(DiscoveryNode[] nodeIds, String modelId) { + public MLUpdateModelCacheNodesRequest(DiscoveryNode[] nodeIds, String modelId) { super(nodeIds); this.modelId = modelId; } diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 23c4d007dd..cd79560e90 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -66,9 +66,6 @@ dependencies { implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'com.jayway.jsonpath:json-path:2.8.0' implementation group: 'org.json', name: 'json', version: '20231013' - implementation "com.amazonaws:aws-java-sdk-core:1.12.48" - implementation "com.amazonaws:aws-java-sdk-sts:1.12.48" - implementation "com.amazonaws:aws-java-sdk-secretsmanager:1.12.48" } lombok { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index a31bd24674..178228992e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -6,27 +6,21 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; -import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; import static software.amazon.awssdk.http.SdkHttpMethod.POST; import java.io.BufferedReader; -import java.io.IOException; import java.io.InputStreamReader; import java.net.URI; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedExceptionAction; -import java.util.HashMap; import java.util.List; import java.util.Map; import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.TokenBucket; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.AwsConnector; @@ -35,9 +29,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; -import org.opensearch.ml.engine.credentials.aws.InternalAwsCredentials; -import org.opensearch.ml.engine.credentialscommunication.Credentials; -import org.opensearch.ml.engine.credentialscommunication.CredentialsRequest; import org.opensearch.script.ScriptService; import lombok.Getter; @@ -63,9 +54,6 @@ public class AwsConnectorExecutor implements RemoteConnectorExecutor { private ScriptService scriptService; @Setter @Getter - private ClusterService clusterService; - @Setter - @Getter private TokenBucket modelRateLimiter; @Setter @Getter @@ -83,95 +71,6 @@ public AwsConnectorExecutor(Connector connector) { this(connector, new DefaultSdkHttpClientBuilder().build()); } - private Map getCredentialsFromIAMRole(String roleArn, String clusterName) throws IOException { - Map awsCredentials = new HashMap<>(); - try { - CredentialsRequest credentialsRequest = new CredentialsRequest(roleArn, clusterName); - InternalAwsCredentials credentials = Credentials.getCredentials(credentialsRequest); - awsCredentials.put(ACCESS_KEY_FIELD, credentials.getAccessKey()); - awsCredentials.put(SECRET_KEY_FIELD, credentials.getSecretKey()); - awsCredentials.put(SESSION_TOKEN_FIELD, credentials.getSessionToken()); - } catch (Exception ex) { - log.error("Exception occurred gaining credentials: " + ex); - throw ex; - } - return awsCredentials; - } - - @Override - public void invokeRemoteModelInManagedService( - MLInput mlInput, - Map parameters, - String payload, - List tensorOutputs - ) { - try { - String clusterName = clusterService.getClusterName().toString(); - String roleArn = ""; - if (connector.getDecryptedCredential().get("roleArn") != null) { - roleArn = connector.getDecryptedCredential().get("roleArn"); - } - Map awsCredentials = getCredentialsFromIAMRole(roleArn, clusterName); - String endpoint = connector.getPredictEndpoint(parameters); - RequestBody requestBody = RequestBody.fromString(payload); - - SdkHttpFullRequest.Builder builder = SdkHttpFullRequest - .builder() - .method(POST) - .uri(URI.create(endpoint)) - .contentStreamProvider(requestBody.contentStreamProvider()); - Map headers = connector.getDecryptedHeaders(); - if (headers != null) { - for (String key : headers.keySet()) { - builder.putHeader(key, headers.get(key)); - } - } - SdkHttpFullRequest request = builder.build(); - HttpExecuteRequest executeRequest = HttpExecuteRequest - .builder() - .request(signRequestInManagedService(request, awsCredentials)) - .contentStreamProvider(request.contentStreamProvider().orElse(null)) - .build(); - - HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - return httpClient.prepareRequest(executeRequest).call(); - }); - int statusCode = response.httpResponse().statusCode(); - - AbortableInputStream body = null; - if (response.responseBody().isPresent()) { - body = response.responseBody().get(); - } - - StringBuilder responseBuilder = new StringBuilder(); - if (body != null) { - try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - responseBuilder.append(line); - } - } - } else { - throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); - } - String modelResponse = responseBuilder.toString(); - if (statusCode < 200 || statusCode >= 300) { - throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); - } - - ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); - tensors.setStatusCode(statusCode); - tensorOutputs.add(tensors); - } catch (RuntimeException exception) { - log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); - throw exception; - } catch (Throwable e) { - log.error("Failed to execute predict in aws connector", e); - throw new MLException("Fail to execute predict in aws connector", e); - } - - } - @Override public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { try { @@ -243,13 +142,4 @@ private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) { return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region); } - - private SdkHttpFullRequest signRequestInManagedService(SdkHttpFullRequest request, Map awsCredentials) { - String accessKey = awsCredentials.get(ACCESS_KEY_FIELD); - String secretKey = awsCredentials.get(SECRET_KEY_FIELD); - String sessionToken = awsCredentials.get(SESSION_TOKEN_FIELD); - String signingName = connector.getServiceName(); - String region = connector.getRegion(); - return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 13b2a356ac..d881707195 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -11,16 +11,11 @@ import java.security.AccessController; import java.security.PrivilegedExceptionAction; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import org.apache.commons.text.StringSubstitutor; import org.apache.http.HttpEntity; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; @@ -31,7 +26,6 @@ import org.apache.http.util.EntityUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.TokenBucket; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.Connector; @@ -40,14 +34,9 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; -import org.opensearch.ml.engine.credentialscommunication.SecretManagerCredentials; -import org.opensearch.ml.engine.credentialscommunication.SecretsManager; import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -62,9 +51,6 @@ public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { @Getter private ScriptService scriptService; - @Setter - @Getter - private ClusterService clusterService; @Setter @Getter private TokenBucket modelRateLimiter; @@ -79,123 +65,6 @@ public HttpJsonConnectorExecutor(Connector connector) { this.connector = (HttpConnector) connector; } - @Override - public void invokeRemoteModelInManagedService( - MLInput mlInput, - Map parameters, - String payload, - List tensorOutputs - ) { - try { - AtomicReference responseRef = new AtomicReference<>(""); - AtomicReference statusCodeRef = new AtomicReference<>(); - HttpUriRequest request; - switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { - case "POST": - try { - String predictEndpoint = connector.getPredictEndpoint(parameters); - request = new HttpPost(predictEndpoint); - HttpEntity entity = new StringEntity(payload); - ((HttpPost) request).setEntity(entity); - } catch (Exception e) { - throw new MLException("Failed to create http request for remote model", e); - } - break; - case "GET": - try { - request = new HttpGet(connector.getPredictEndpoint(parameters)); - } catch (Exception e) { - throw new MLException("Failed to create http request for remote model", e); - } - break; - default: - throw new IllegalArgumentException("unsupported http method"); - } - - Map headers = connector.getDecryptedHeaders(); - - Map secretManagerCredentials = new HashMap<>(); - - boolean hasContentTypeHeader = false; - String secretArnPrefix = "${credential.secretArn."; - String secretArnSuffix = "}"; - String regex = Pattern.quote(secretArnPrefix) + "(.*?)" + Pattern.quote(secretArnSuffix); - - if (headers != null) { - for (String key : headers.keySet()) { - if (headers.get(key).contains(secretArnPrefix)) { - List matches = new ArrayList<>(); - Matcher matcher = Pattern.compile(regex).matcher(headers.get(key)); - while (matcher.find()) { - String match = matcher.group(1); // This is the text between the prefix and suffix - matches.add(match); - } - for (String match : matches) { - secretManagerCredentials.put(match, ""); - } - } else { - request.addHeader(key, headers.get(key)); - } - } - } - - if (!secretManagerCredentials.entrySet().isEmpty()) { - String clusterName = clusterService.getClusterName().toString(); - String roleArn = connector.getDecryptedCredential().get("roleArn"); - String secretArn = connector.getDecryptedCredential().get("secretArn"); - - SecretManagerCredentials secretManagerCredentialsRequest = new SecretManagerCredentials(roleArn, clusterName, secretArn); - JsonObject secretManagerResponse = SecretsManager.getSecretValue(secretManagerCredentialsRequest); - for (String key : secretManagerCredentials.keySet()) { - JsonElement secretValue = secretManagerResponse.get(key); - String secretVal = secretValue.isJsonNull() ? "" : secretValue.getAsString(); - secretManagerCredentials.put(key, secretVal); - } - - } - - StringSubstitutor substitutor = new StringSubstitutor(secretManagerCredentials, "${credential.secretArn.", "}"); - if (headers != null) { - for (String key : headers.keySet()) { - headers.put(key, substitutor.replace(headers.get(key))); - if (!request.containsHeader(key)) { - request.addHeader(key, substitutor.replace(headers.get(key))); - } - } - } - - if (!hasContentTypeHeader) { - request.addHeader("Content-Type", "application/json"); - } - - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - try (CloseableHttpClient httpClient = getHttpClient(); CloseableHttpResponse response = httpClient.execute(request)) { - HttpEntity responseEntity = response.getEntity(); - String responseBody = EntityUtils.toString(responseEntity); - EntityUtils.consume(responseEntity); - responseRef.set(responseBody); - statusCodeRef.set(response.getStatusLine().getStatusCode()); - } - return null; - }); - String modelResponse = responseRef.get(); - Integer statusCode = statusCodeRef.get(); - if (statusCode < 200 || statusCode >= 300) { - throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); - } - - ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); - tensors.setStatusCode(statusCode); - tensorOutputs.add(tensors); - } catch (RuntimeException e) { - log.error("Fail to execute http connector in managed service", e); - throw e; - } catch (Throwable e) { - log.error("Fail to execute http connector in managed service", e); - throw new MLException("Fail to execute http connector in managed service", e); - } - } - @Override public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index b08df7236f..c03dae34db 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -120,26 +120,19 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List parameters, String payload, List tensorOutputs); - void invokeRemoteModelInManagedService( - MLInput mlInput, - Map parameters, - String payload, - List tensorOutputs - ); - } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/CredentialsProvider.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/CredentialsProvider.java deleted file mode 100644 index 35ec2909a4..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/CredentialsProvider.java +++ /dev/null @@ -1,7 +0,0 @@ -package org.opensearch.ml.engine.credentials; - -import com.amazonaws.auth.AWSCredentialsProvider; - -public interface CredentialsProvider { - AWSCredentialsProvider getCredentialsProvider(String region, String roleArn); -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ClientConfigurationHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ClientConfigurationHelper.java deleted file mode 100644 index 7e9c8dcc63..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ClientConfigurationHelper.java +++ /dev/null @@ -1,59 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.core.common.Strings; - -import com.amazonaws.ClientConfiguration; - -public class ClientConfigurationHelper { - private ClientConfigurationHelper() { - // no instance - } - - protected static final String SOURCE_ACCOUNT_HEADER = "x-amz-source-account"; - protected static final String SOURCE_ARN_HEADER = "x-amz-source-arn"; - protected static final String OS_DOMAIN_ARN_FORMAT = "arn:%s:es:%s:%s:domain/%s"; - - private final static Logger LOGGER = LogManager.getLogger(ClientConfigurationHelper.class); - - public static ClientConfiguration getConfusedDeputyConfiguration(String[] clusterNameTuple, String region) { - String clientId = clusterNameTuple[0]; - String domainArn = generateDomainArn(clusterNameTuple, region); - - // Confused Deputy Protection Requirement - // https://w.amazon.com/bin/view/AWSAuth/AccessManagement/Resource_Policy_Confused_Deputy_Protection - LOGGER - .debug( - "Adding Source ARN " + domainArn + " and Source Account " + clientId + " in request headers for Confused Deputy Protection" - ); - return new ClientConfiguration().withHeader(SOURCE_ARN_HEADER, domainArn).withHeader(SOURCE_ACCOUNT_HEADER, clientId); - } - - protected static String generateDomainArn(String[] clusterNameTuple, String region) { - String partition = getPartition(region); - return String.format(OS_DOMAIN_ARN_FORMAT, partition, region, clusterNameTuple[0], clusterNameTuple[1]); - } - - protected static String getPartition(String region) { - final String partition = System.getenv("DOMAIN_PARTITION"); - - if (!Strings.isNullOrEmpty(partition)) { - return partition; - } - LOGGER.warn("Domain Partition is missing from environment variable, assuming partition on the basis of current region"); - if (region.contains("gov")) { - return "aws-us-gov"; - } - if (region.contains("-isob-")) { - return "aws-iso-b"; - } - if (region.contains("-iso-")) { - return "aws-iso"; - } - if (region.startsWith("cn-")) { - return "aws-cn"; - } - return "aws"; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/CredentialsProviderFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/CredentialsProviderFactory.java deleted file mode 100644 index a7f3973244..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/CredentialsProviderFactory.java +++ /dev/null @@ -1,10 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import com.amazonaws.auth.AWSCredentialsProvider; - -/** - * Interface which enables to plug in multiple Credentials Providers - */ -public interface CredentialsProviderFactory { - public AWSCredentialsProvider getProvider(String roleArn); -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ExpirableCredentialsProviderFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ExpirableCredentialsProviderFactory.java deleted file mode 100644 index 7f6eb53fc2..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/ExpirableCredentialsProviderFactory.java +++ /dev/null @@ -1,84 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; -import com.amazonaws.services.securitytoken.AWSSecurityTokenService; -import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; -import com.amazonaws.util.EC2MetadataUtils; - -/** - * Factory class that provides temporary credentials. It refreshes the credentials on demand. - */ -public class ExpirableCredentialsProviderFactory implements CredentialsProviderFactory { - - public ExpirableCredentialsProviderFactory(InternalAuthCredentialsClient internalAuthCredentialsClient, String[] clusterNameTuple) { - this.internalAuthCredentialsClient = internalAuthCredentialsClient; - this.clusterNameTuple = clusterNameTuple; - } - - /** - * Provide expirable credentials. - * - * @param roleArn IAM role arn - * @return AWSCredentialsProvider which holds the credentials. - */ - @Override - public AWSCredentialsProvider getProvider(String roleArn) { - return getExpirableCredentialsProvider(roleArn); - } - - private static final Logger logger = LogManager.getLogger(ExpirableCredentialsProviderFactory.class); - - private final InternalAuthCredentialsClient internalAuthCredentialsClient; - private final String[] clusterNameTuple; - - private AWSCredentialsProvider getExpirableCredentialsProvider(String roleArn) { - return findStsAssumeRoleCredentialsProvider(roleArn); - } - - private AWSCredentialsProvider findStsAssumeRoleCredentialsProvider(String roleArn) { - AWSCredentialsProvider assumeRoleApiCredentialsProvider = getAssumeRoleApiCredentialsProvider(); - - if (assumeRoleApiCredentialsProvider != null) { - logger.info("Fetching credentials from STS for assumed role"); - return getStsAssumeCustomerRoleProvider(assumeRoleApiCredentialsProvider, roleArn); - } - logger.info("Could not fetch credentials from internal service to assume role"); - return null; - } - - private AWSCredentialsProvider getAssumeRoleApiCredentialsProvider() { - InternalAuthApiCredentialsProvider internalAuthApiCredentialsProvider = new InternalAuthApiCredentialsProvider( - internalAuthCredentialsClient, - InternalAuthApiCredentialsProvider.POLICY_TYPES.get("ASSUME_ROLE") - ); - return internalAuthApiCredentialsProvider.getCredentials() != null ? internalAuthApiCredentialsProvider : null; - } - - private AWSCredentialsProvider getStsAssumeCustomerRoleProvider(AWSCredentialsProvider apiCredentialsProvider, String roleArn) { - String region = "us-east-1"; - try { - region = EC2MetadataUtils.getEC2InstanceRegion(); - } catch (Exception ex) { - logger.info("Exception occurred while fetching the region info from EC2 metadata. Defaulting to us-east-1"); - } - - final ClientConfiguration configurationWithConfusedDeputyHeaders = ClientConfigurationHelper - .getConfusedDeputyConfiguration(clusterNameTuple, region); - AWSSecurityTokenServiceClientBuilder stsClientBuilder = AWSSecurityTokenServiceClientBuilder - .standard() - .withCredentials(apiCredentialsProvider) - .withClientConfiguration(configurationWithConfusedDeputyHeaders) - .withRegion(region); - AWSSecurityTokenService stsClient = stsClientBuilder.build(); - STSAssumeRoleSessionCredentialsProvider.Builder providerBuilder = new STSAssumeRoleSessionCredentialsProvider.Builder( - roleArn, - "ml-commons" - ).withStsClient(stsClient); - return new PrivilegedCredentialsProvider(providerBuilder.build()); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthApiCredentialsProvider.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthApiCredentialsProvider.java deleted file mode 100644 index 91201bca77..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthApiCredentialsProvider.java +++ /dev/null @@ -1,86 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import java.util.HashMap; -import java.util.Map; - -import org.opensearch.common.unit.TimeValue; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.BasicSessionCredentials; - -/** - * This classes fetches credentials to assume role by making AWS ES internal service call - */ -class InternalAuthApiCredentialsProvider implements AWSCredentialsProvider { - - public static final Map POLICY_TYPES; - static { - POLICY_TYPES = new HashMap(); - POLICY_TYPES.put("ASSUME_ROLE", "AR"); - } - private final InternalAuthCredentialsClient internalApiCredentialsClient; - private final String policyType; - private AWSCredentials awsCredentials; - private long expiryTimestamp; - - InternalAuthApiCredentialsProvider(InternalAuthCredentialsClient internalApiCredentialsClient, String policyType) { - this.internalApiCredentialsClient = internalApiCredentialsClient; - this.policyType = policyType; - } - - /** - * Fetches credentials. It refreshes the credentials if expired - * - * @return AWSCredentials - */ - @Override - public AWSCredentials getCredentials() { - if (credentialsHaveExpired()) { - refresh(); - } - - return awsCredentials; - } - - /** - * Refreshes credentials - */ - @Override - public synchronized void refresh() { - if (!credentialsHaveExpired()) { - return; - } - - InternalAwsCredentials apiCredentials = internalApiCredentialsClient.getAwsCredentials(policyType); - - if (apiCredentials == null) { - resetCredentials(); - } else { - this.awsCredentials = new BasicSessionCredentials( - apiCredentials.getAccessKey(), - apiCredentials.getSecretKey(), - apiCredentials.getSessionToken() - ); - this.expiryTimestamp = apiCredentials.getExpiry() - TimeValue.timeValueSeconds(10).millis(); - } - } - - /** - * Gets the expiry timestamp of the temporary credentials - * - * @return expiry timestamp - */ - public long getExpiryTimestamp() { - return expiryTimestamp; - } - - private boolean credentialsHaveExpired() { - return awsCredentials == null || System.currentTimeMillis() > expiryTimestamp; - } - - private void resetCredentials() { - this.awsCredentials = null; - this.expiryTimestamp = 0; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsApiRequest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsApiRequest.java deleted file mode 100644 index de8a318105..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsApiRequest.java +++ /dev/null @@ -1,108 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; - -import org.apache.http.HttpEntity; -import org.apache.http.HttpResponse; -import org.apache.http.client.methods.HttpGet; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.util.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import com.fasterxml.jackson.core.JsonParseException; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategy; - -/** - * This class handles the connections to AWS ES internal service endpoint, to - * fetch the temporary credentials to assume the role. - */ -class InternalAuthCredentialsApiRequest { - - private static final Logger logger = LogManager.getLogger(InternalAuthCredentialsApiRequest.class); - private static final InternalAwsCredentials EMPTY_CREDENTIALS = new InternalAwsCredentials(); - private static final String ENDPOINT = "http://localhost:9200/_internal/auth"; - private final CloseableHttpClient httpClient; - private final String policyType; - - private static ObjectMapper JSON_MAPPER = new ObjectMapper(); - static { - JSON_MAPPER.setPropertyNamingStrategy(new PropertyNamingStrategy.LowerCaseWithUnderscoresStrategy()); - } - - InternalAuthCredentialsApiRequest(CloseableHttpClient httpClient, String policyType) { - this.httpClient = httpClient; - this.policyType = policyType; - } - - InternalAwsCredentials execute() throws IOException { - HttpResponse response = getHttpResponse(); - validateResponseStatus(response); - String responseString = getResponseString(response); - return httpResponseAsCredentialsObject(responseString); - } - - private HttpResponse getHttpResponse() throws IOException { - HttpGet internalAuthGetRequest = new HttpGet(internalAuthUri()); - - return httpClient.execute(internalAuthGetRequest); - } - - private URI internalAuthUri() { - try { - return new URIBuilder(ENDPOINT).addParameter("policy_id", policyType).build(); - } catch (URISyntaxException exception) { - logger.error(exception); - throw new IllegalStateException("Error creating URI"); - } - } - - private String getResponseString(HttpResponse response) throws IOException { - HttpEntity entity = response.getEntity(); - if (entity == null) - return "{}"; - - String responseString = EntityUtils.toString(entity); - logger.debug("Internal Auth response: " + responseString); - - return responseString; - } - - private String getResponseString(HttpEntity entity) throws IOException { - if (entity == null) - return "{}"; - - String responseString = EntityUtils.toString(entity); - logger.debug("Internal Auth response: " + responseString); - - return responseString; - } - - private void validateResponseStatus(HttpResponse response) throws IOException { - int statusCode = response.getStatusLine().getStatusCode(); - - if (statusCode != 200) { - throw new IOException("Request to internal auth failed with not OK response"); - } - } - - private InternalAwsCredentials httpResponseAsCredentialsObject(String responseString) throws IOException { - try { - return JSON_MAPPER.readValue(responseString, InternalAwsCredentials.class); - } catch (JsonParseException e) { - logger.error("Error in parsing internal aws credentials response", e); - return EMPTY_CREDENTIALS; - } catch (JsonMappingException e) { - logger.error("Error in parsing internal aws credentials response", e); - return EMPTY_CREDENTIALS; - } catch (IOException e) { - logger.error("Error in parsing internal aws credentials response", e); - return EMPTY_CREDENTIALS; - } - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClient.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClient.java deleted file mode 100644 index b45ed36642..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClient.java +++ /dev/null @@ -1,64 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import java.io.IOException; - -import org.apache.http.client.config.RequestConfig; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.DefaultHttpRequestRetryHandler; -import org.apache.http.impl.client.HttpClientBuilder; -import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.common.unit.TimeValue; - -/** - * This class handles client configuration for AWS ES internal service calls - */ -public class InternalAuthCredentialsClient { - - private static final Logger logger = LogManager.getLogger(InternalAuthCredentialsClient.class); - - private static final int TIMEOUT_MILLISECONDS = (int) TimeValue.timeValueSeconds(5).millis(); - private static final int SOCKET_TIMEOUT_MILLISECONDS = (int) TimeValue.timeValueSeconds(70).millis(); - - private final static CloseableHttpClient HTTP_CLIENT; - - static { - HTTP_CLIENT = createHttpClient(); - } - - public InternalAwsCredentials getAwsCredentials(String policyType) { - try { - InternalAwsCredentials internalAwsCredentials = getInternalAwsCredentials(policyType); - - return !internalAwsCredentials.isEmpty() ? internalAwsCredentials : null; - - } catch (IOException e) { - logger.error("Could not fetch AWS credentials", e); - return null; - } - } - - private static CloseableHttpClient createHttpClient() { - RequestConfig config = RequestConfig - .custom() - .setConnectTimeout(TIMEOUT_MILLISECONDS) - .setConnectionRequestTimeout(TIMEOUT_MILLISECONDS) - .setSocketTimeout(SOCKET_TIMEOUT_MILLISECONDS) - .build(); - - PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager(); - connectionManager.setDefaultMaxPerRoute(5); - - return HttpClientBuilder - .create() - .setDefaultRequestConfig(config) - .setConnectionManager(connectionManager) - .setRetryHandler(new DefaultHttpRequestRetryHandler()) - .build(); - } - - private InternalAwsCredentials getInternalAwsCredentials(String policyType) throws IOException { - return (new InternalAuthCredentialsApiRequest(HTTP_CLIENT, policyType)).execute(); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClientPool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClientPool.java deleted file mode 100644 index 47792cca8a..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAuthCredentialsClientPool.java +++ /dev/null @@ -1,36 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import java.util.HashMap; -import java.util.Map; - -/** - * This class fetches credentials provider from different sources(based on priority) and uses the first one that works. - */ -public final class InternalAuthCredentialsClientPool { - - private static final InternalAuthCredentialsClientPool instance = new InternalAuthCredentialsClientPool(); - - private Map clientPool; - - public static InternalAuthCredentialsClientPool getInstance() { - return instance; - } - - public synchronized InternalAuthCredentialsClient getInternalAuthClient(String factoryName) { - if (clientPool.containsKey(factoryName)) { - return clientPool.get(factoryName); - } - - return newClient(factoryName); - } - - private InternalAuthCredentialsClientPool() { - this.clientPool = new HashMap(); - } - - private InternalAuthCredentialsClient newClient(String factoryName) { - InternalAuthCredentialsClient client = new InternalAuthCredentialsClient(); - clientPool.put(factoryName, client); - return client; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAwsCredentials.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAwsCredentials.java deleted file mode 100644 index 3549a0e459..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/InternalAwsCredentials.java +++ /dev/null @@ -1,56 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -/** - * This class is a placeholder for credentials - */ -public class InternalAwsCredentials { - - private String accessKey; - private String secretKey; - private String sessionToken; - private long expiry; - - public InternalAwsCredentials(String accessKey, String secretKey, String sessionToken, long expiry) { - this.accessKey = accessKey; - this.secretKey = secretKey; - this.sessionToken = sessionToken; - } - - public InternalAwsCredentials() {} - - public String getAccessKey() { - return accessKey; - } - - public String getSecretKey() { - return secretKey; - } - - public String getSessionToken() { - return sessionToken; - } - - public long getExpiry() { - return expiry; - } - - public void setAccessKey(String accessKey) { - this.accessKey = accessKey; - } - - public void setSecretKey(String secretKey) { - this.secretKey = secretKey; - } - - public void setSessionToken(String sessionToken) { - this.sessionToken = sessionToken; - } - - public void setExpiry(long expiryTimestamp) { - this.expiry = expiryTimestamp; - } - - public boolean isEmpty() { - return accessKey == null || secretKey == null || sessionToken == null; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/PrivilegedCredentialsProvider.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/PrivilegedCredentialsProvider.java deleted file mode 100644 index 570676de2e..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/PrivilegedCredentialsProvider.java +++ /dev/null @@ -1,31 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; - -/** - * This class helps in fetching the credentials by making socket connections in - * privileged mode. - */ -public class PrivilegedCredentialsProvider implements AWSCredentialsProvider { - - private final AWSCredentialsProvider credentials; - - PrivilegedCredentialsProvider(AWSCredentialsProvider delegate) { - this.credentials = delegate; - } - - @Override - public AWSCredentials getCredentials() { - return SocketAccess.doPrivileged(credentials::getCredentials); - } - - @Override - public void refresh() { - SocketAccess.doPrivilegedVoid(credentials::refresh); - } - - public AWSCredentialsProvider wrappedProvider() { - return credentials; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/SocketAccess.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/SocketAccess.java deleted file mode 100644 index cd65066885..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentials/aws/SocketAccess.java +++ /dev/null @@ -1,42 +0,0 @@ -package org.opensearch.ml.engine.credentials.aws; - -import java.io.IOException; -import java.net.SocketPermission; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; - -import org.opensearch.SpecialPermission; - -/** - * This plugin uses aws libraries to connect to STS. For these remote calls the plugin needs - * {@link SocketPermission} 'connect' to establish connections. This class wraps the operations requiring access in - * {@link AccessController#doPrivileged(PrivilegedAction)} blocks. - */ -final class SocketAccess { - - private SocketAccess() {} - - public static T doPrivileged(PrivilegedAction operation) { - SpecialPermission.check(); - return AccessController.doPrivileged(operation); - } - - public static T doPrivilegedIOException(PrivilegedExceptionAction operation) throws IOException { - SpecialPermission.check(); - try { - return AccessController.doPrivileged(operation); - } catch (PrivilegedActionException e) { - throw (IOException) e.getCause(); - } - } - - public static void doPrivilegedVoid(Runnable action) { - SpecialPermission.check(); - AccessController.doPrivileged((PrivilegedAction) () -> { - action.run(); - return null; - }); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Credentials.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Credentials.java deleted file mode 100644 index 1916f86e8d..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Credentials.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.opensearch.ml.engine.credentialscommunication; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; - -import org.opensearch.ml.engine.credentials.aws.InternalAwsCredentials; -import org.opensearch.ml.engine.factory.CredentialsFactory; - -public class Credentials { - /** - * Retrieves the credentials for the given credentialsRequest through ExpirableCredentialsProviderFactory class - * - */ - public static InternalAwsCredentials getCredentials(CredentialsRequest credentialsRequest) throws IOException { - return AccessController.doPrivileged((PrivilegedAction) () -> { - CredentialsFactory credentialsProviderFactory = new CredentialsFactory(); - return credentialsProviderFactory.getAWSCredentialsProvider(credentialsRequest); - }); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/CredentialsRequest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/CredentialsRequest.java deleted file mode 100644 index d71c6087dd..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/CredentialsRequest.java +++ /dev/null @@ -1,29 +0,0 @@ -package org.opensearch.ml.engine.credentialscommunication; - -import org.opensearch.core.common.Strings; - -public class CredentialsRequest { - private String clusterName; - private String roleArn; - - public CredentialsRequest(final String roleArn, final String clusterName) { - if (Strings.isNullOrEmpty(roleArn) || !Util.isValidIAMArn(roleArn)) { - throw new IllegalArgumentException("Role arn is missing/invalid: " + roleArn); - } - this.roleArn = roleArn; - this.clusterName = clusterName; - } - - @Override - public String toString() { - return "RoleArn: " + roleArn + ", ClusterName: " + clusterName; - } - - public String getRoleArn() { - return roleArn; - } - - public String getClusterName() { - return clusterName; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretManagerCredentials.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretManagerCredentials.java deleted file mode 100644 index 1911abc2df..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretManagerCredentials.java +++ /dev/null @@ -1,42 +0,0 @@ -package org.opensearch.ml.engine.credentialscommunication; - -import org.opensearch.core.common.Strings; - -public class SecretManagerCredentials { - - private String clusterName; - private String roleArn; - private String secretArn; - - public SecretManagerCredentials(final String roleArn, final String clusterName, final String secretArn) { - - if (Strings.isNullOrEmpty(roleArn) || !Util.isValidIAMArn(roleArn)) { - throw new IllegalArgumentException("Role arn is missing/invalid: " + roleArn); - } - - if (Strings.isNullOrEmpty(secretArn) || !Util.isValidSecretManagerArn(secretArn)) { - throw new IllegalArgumentException("secret arn is missing/invalid: " + secretArn); - } - - this.roleArn = roleArn; - this.clusterName = clusterName; - this.secretArn = secretArn; - } - - @Override - public String toString() { - return "RoleARn: " + roleArn + ", ClusterName: " + clusterName + ", secretArn: " + secretArn; - } - - public String getSecretArn() { - return secretArn; - } - - public String getRoleArn() { - return roleArn; - } - - public String getClusterName() { - return clusterName; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretsManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretsManager.java deleted file mode 100644 index 85197f491d..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/SecretsManager.java +++ /dev/null @@ -1,23 +0,0 @@ -package org.opensearch.ml.engine.credentialscommunication; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; - -import org.opensearch.ml.engine.factory.SecretsManagerFactory; - -import com.google.gson.JsonObject; - -public class SecretsManager { - - /** - * Retrieves the secretValue key pair mapping for the given requested secret - * - */ - public static JsonObject getSecretValue(SecretManagerCredentials secretManagerCredentials) throws IOException { - return AccessController.doPrivileged((PrivilegedAction) () -> { - SecretsManagerFactory secretsManagerFactory = new SecretsManagerFactory(); - return secretsManagerFactory.getSecrets(secretManagerCredentials); - }); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Util.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Util.java deleted file mode 100644 index 80b66bd139..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/credentialscommunication/Util.java +++ /dev/null @@ -1,38 +0,0 @@ -package org.opensearch.ml.engine.credentialscommunication; - -import java.util.regex.Pattern; - -import org.opensearch.common.ValidationException; -import org.opensearch.core.common.Strings; - -public class Util { - - private Util() {} - - public static final Pattern SECRET_ARN_REGEX = Pattern - .compile("^arn:aws(-[^:]+)?:secretsmanager:([a-zA-Z0-9-]+):([0-9]{12}):secret:([a-zA-Z0-9-/_+=@.,]+)$"); - public static final Pattern IAM_ARN_REGEX = Pattern.compile("^arn:aws(-[^:]+)?:iam::([0-9]{12}):([a-zA-Z0-9-/_+=@.,]+)$"); - - public static String getRegionFromSecretArn(String secretArn) { - if (isValidSecretManagerArn(secretArn)) { - return secretArn.split(":")[3]; - } - throw new IllegalArgumentException("Unable to retrieve region from secretARN " + secretArn); - } - - public static boolean isValidIAMArn(String arn) { - return Strings.hasLength(arn) && IAM_ARN_REGEX.matcher(arn).find(); - } - - public static boolean isValidSecretManagerArn(String secretArn) throws ValidationException { - return Strings.hasLength(secretArn) && SECRET_ARN_REGEX.matcher(secretArn).find(); - } - - public static boolean isValidContentType(String contentType) { - return contentType.equals("application/json"); - } - - public static boolean isValidAWSService(String serviceName) { - return (serviceName.equalsIgnoreCase("sagemaker") || serviceName.equalsIgnoreCase("bedrock")); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/CredentialsFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/CredentialsFactory.java deleted file mode 100644 index 10822f126a..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/CredentialsFactory.java +++ /dev/null @@ -1,69 +0,0 @@ -package org.opensearch.ml.engine.factory; - -import java.util.HashMap; -import java.util.Map; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ml.engine.credentials.aws.ExpirableCredentialsProviderFactory; -import org.opensearch.ml.engine.credentials.aws.InternalAuthCredentialsClient; -import org.opensearch.ml.engine.credentials.aws.InternalAuthCredentialsClientPool; -import org.opensearch.ml.engine.credentials.aws.InternalAwsCredentials; -import org.opensearch.ml.engine.credentials.aws.PrivilegedCredentialsProvider; -import org.opensearch.ml.engine.credentialscommunication.CredentialsRequest; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.BasicSessionCredentials; - -public class CredentialsFactory { - private static final Logger logger = LogManager.getLogger(CredentialsFactory.class); - - private final InternalAuthCredentialsClient internalApiCredentialsClient; - - /* - * Mapping between IAM roleArn and AWSCredentialsProvider. Each role will have its own credentials. - */ - Map roleClientMap = new HashMap<>(); - - public CredentialsFactory() { - this.internalApiCredentialsClient = InternalAuthCredentialsClientPool.getInstance().getInternalAuthClient(getClass().getName()); - } - - /** - * Fetches the client corresponding to an IAM role - * - * @return AmazonSNS AWS SNS client - */ - public InternalAwsCredentials getAWSCredentialsProvider(CredentialsRequest credentialsRequest) { - AWSCredentialsProvider credentialsProvider; - String roleArn = credentialsRequest.getRoleArn(); - String clusterName = credentialsRequest.getClusterName(); - if (!roleClientMap.containsKey(roleArn)) { - credentialsProvider = getProvider(roleArn, clusterName); - roleClientMap.put(roleArn, credentialsProvider); - } - AWSCredentialsProvider awsCredentialsProvider = roleClientMap.get(roleArn); - PrivilegedCredentialsProvider privilegedCredentialsProvider = (PrivilegedCredentialsProvider) awsCredentialsProvider; - BasicSessionCredentials basic = (BasicSessionCredentials) privilegedCredentialsProvider.getCredentials(); - InternalAwsCredentials apiCredentials = new InternalAwsCredentials( - basic.getAWSAccessKeyId(), - basic.getAWSSecretKey(), - basic.getSessionToken(), - 0 - ); - return apiCredentials; - } - - /** - * @param roleArn - * @return AWSCredentialsProvider - * @throws IllegalArgumentException - */ - public AWSCredentialsProvider getProvider(String roleArn, String clusterName) throws IllegalArgumentException { - org.opensearch.ml.engine.credentials.aws.CredentialsProviderFactory providerSource = new ExpirableCredentialsProviderFactory( - internalApiCredentialsClient, - clusterName.split(":") - ); - return providerSource.getProvider(roleArn); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/SecretsManagerFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/SecretsManagerFactory.java deleted file mode 100644 index e2a4420868..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/factory/SecretsManagerFactory.java +++ /dev/null @@ -1,110 +0,0 @@ -package org.opensearch.ml.engine.factory; - -import java.util.HashMap; -import java.util.Map; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ml.engine.credentials.aws.CredentialsProviderFactory; -import org.opensearch.ml.engine.credentials.aws.ExpirableCredentialsProviderFactory; -import org.opensearch.ml.engine.credentials.aws.InternalAuthCredentialsClient; -import org.opensearch.ml.engine.credentials.aws.InternalAuthCredentialsClientPool; -import org.opensearch.ml.engine.credentialscommunication.SecretManagerCredentials; -import org.opensearch.ml.engine.credentialscommunication.Util; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.services.secretsmanager.AWSSecretsManager; -import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder; -import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest; -import com.amazonaws.services.secretsmanager.model.GetSecretValueResult; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; - -final public class SecretsManagerFactory { - private static final Logger logger = LogManager.getLogger(SecretsManagerFactory.class); - - private final InternalAuthCredentialsClient internalApiCredentialsClient; - - /* - * Mapping between IAM roleArn and SecretManagerClientHelper. Each role will have its own credentials. - */ - Map roleClientMap = new HashMap<>(); - - public SecretsManagerFactory() { - this.internalApiCredentialsClient = InternalAuthCredentialsClientPool.getInstance().getInternalAuthClient(getClass().getName()); - } - - public JsonObject getSecrets(SecretManagerCredentials secretCredentials) { - try { - AWSSecretsManager secretsManager = getClient(secretCredentials); - GetSecretValueRequest secretsRequest = new GetSecretValueRequest(); - secretsRequest.setSecretId(secretCredentials.getSecretArn()); - GetSecretValueResult secretValueResponse = secretsManager.getSecretValue(secretsRequest); - JsonObject jsonObject = JsonParser.parseString(secretValueResponse.getSecretString()).getAsJsonObject(); - return jsonObject; - } catch (Exception ex) { - logger.error("Exception getting secrets from SecretManager", ex); - throw ex; - } - } - - /** - * Fetches the client corresponding to an IAM role - * - * @return AWSSecretsManager AWS SecretsManager client - */ - public AWSSecretsManager getClient(SecretManagerCredentials secretCredentials) { - AWSCredentialsProvider credentialsProvider; - String roleArn = secretCredentials.getRoleArn(); - String clusterName = secretCredentials.getClusterName(); - if (!roleClientMap.containsKey(roleArn)) { - credentialsProvider = getProvider(roleArn, clusterName); - roleClientMap.put(roleArn, new SecretManagerClientHelper(credentialsProvider)); - } - - AWSSecretsManager secretsManagerClient = roleClientMap - .get(roleArn) - .getSecretManagerClient(Util.getRegionFromSecretArn(secretCredentials.getSecretArn())); - return secretsManagerClient; - } - - /** - * @param roleArn - * @return AWSCredentialsProvider - * @throws IllegalArgumentException - */ - public AWSCredentialsProvider getProvider(String roleArn, String clusterName) throws IllegalArgumentException { - - CredentialsProviderFactory providerSource = new ExpirableCredentialsProviderFactory( - internalApiCredentialsClient, - clusterName.split(":") - ); - return providerSource.getProvider(roleArn); - } -} - -/** - * This helper class caches the credentials for a role and creates client - * for each AWS region based on the topic ARN - */ -class SecretManagerClientHelper { - private AWSCredentialsProvider credentialsProvider; - // Map between Region and client - private Map secretManagerClientMap = new HashMap(); - - SecretManagerClientHelper(AWSCredentialsProvider credentialsProvider) { - this.credentialsProvider = credentialsProvider; - } - - public AWSSecretsManager getSecretManagerClient(String region) { - if (!secretManagerClientMap.containsKey(region)) { - AWSSecretsManager secretsManagerClient = AWSSecretsManagerClientBuilder - .standard() - .withRegion(region) - .withCredentials(credentialsProvider) - .build(); - secretManagerClientMap.put(region, secretsManagerClient); - } - return secretManagerClientMap.get(region); - } -} diff --git a/plugin/build.gradle b/plugin/build.gradle index 865c216098..a6fbbf1851 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -336,9 +336,6 @@ configurations.all { resolutionStrategy.force 'org.apache.httpcomponents:httpclient:4.5.14' resolutionStrategy.force 'commons-codec:commons-codec:1.15' resolutionStrategy.force 'org.slf4j:slf4j-api:1.7.36' - resolutionStrategy.force "joda-time:joda-time:2.8.1" // Resolve for amazonaws - resolutionStrategy.force "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" // resolve for amazonaws - resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" // resolve for amazonaws } apply plugin: 'com.netflix.nebula.ospackage' diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 645e5871fa..4cadcc936a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -93,7 +93,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + mlModelManager.updateModelCache(modelId, ActionListener.wrap(r -> { log.info("Successfully performed in-place update model {} on node {}", modelId, localNodeId); }, e -> { log.error("Failed to perform in-place update model for model {} on node {}", modelId, localNodeId); })); return new MLUpdateModelCacheNodeResponse(clusterService.localNode(), modelUpdateStatus); diff --git a/plugin/src/main/java/org/opensearch/ml/action/update/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/update/UpdateModelTransportAction.java index ad3d448ba6..cf9915ec07 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/update/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/update/UpdateModelTransportAction.java @@ -46,7 +46,6 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.controller.MLRateLimiter; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; @@ -131,13 +130,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermission) { - updateRemoteOrTextEmbeddingModel( - modelId, - updateModelInput, - mlModel, - user, - wrappedListener - ); + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); } else { wrappedListener .onFailure( @@ -217,12 +204,10 @@ private void updateRemoteOrTextEmbeddingModel( String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; - - String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; boolean isModelDeployed = isModelDeployed(mlModel.getModelState()); // This flag is used to decide if we need to re-deploy the predictor(model) when performing the in-place update boolean isPredictorUpdate = (updateModelInput.getConnectorUpdateContent() != null) - || (relinkConnectorId != null) + || (newConnectorId != null) || !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled()); if (updateModelInput.getModelRateLimiterConfig() != null) { MLRateLimiter modelRateLimiterConfig = mlModel.getModelRateLimiterConfig(); @@ -398,10 +383,7 @@ private void updateRequestConstructor( updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { String[] targetNodeIds = getAllNodes(); - MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest( - targetNodeIds, - modelId - ); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); client .update( updateRequest, @@ -443,10 +425,7 @@ private void updateRequestConstructor( updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { String[] targetNodeIds = getAllNodes(); - MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest( - targetNodeIds, - modelId - ); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); client.update(updateModelGroupRequest, ActionListener.wrap(r -> { client .update( diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 45055e12cb..b5b87e679d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -69,7 +69,6 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.IndicesOptions; @@ -851,7 +850,7 @@ private void updateModelRegisterStateAsDone( } @VisibleForTesting - private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { + void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { String[] modelNodeIds = registerModelInput.getModelNodeIds(); log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds)); MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true);