Skip to content

Commit

Permalink
revert AOS code
Browse files Browse the repository at this point in the history
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
  • Loading branch information
b4sjoo committed Dec 21, 2023
1 parent d64f03f commit 5a2d455
Show file tree
Hide file tree
Showing 34 changed files with 22 additions and 1,238 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ public abstract class AbstractConnector implements Connector {
public static final String ACCESS_KEY_FIELD = "access_key";
public static final String SECRET_KEY_FIELD = "secret_key";
public static final String SESSION_TOKEN_FIELD = "session_token";
public static final String ROLE_ARN_FIELD = "roleArn";
public static final String SECRET_ARN_FIELD = "secretArn";
public static final String NAME_FIELD = "name";
public static final String VERSION_FIELD = "version";
public static final String DESCRIPTION_FIELD = "description";
Expand All @@ -52,7 +50,6 @@ public abstract class AbstractConnector implements Connector {
protected String protocol;

protected Map<String, String> parameters;
@Getter
protected Map<String, String> credential;
protected Map<String, String> decryptedHeaders;
@Setter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,18 @@ public AwsConnector(String name, String description, String version, String prot
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode, User owner) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner);
//validate();
validateAwsConnectorInManagedService();
validate();
}

public AwsConnector(String protocol, XContentParser parser) throws IOException {
super(protocol, parser);
//validate();
validateAwsConnectorInManagedService();
validate();
}


public AwsConnector(StreamInput input) throws IOException {
super(input);
//validate();
validateAwsConnectorInManagedService();
validate();
}

private void validate() {
Expand All @@ -62,19 +59,6 @@ private void validate() {
}
}

private void validateAwsConnectorInManagedService() {
// Users who are using AWS protocol must give a roleArn in credentials
if (credential == null || !credential.containsKey("roleArn")) {
throw new IllegalArgumentException("please supply a valid roleArn in credentials if utilizing an AWS service");
}
if ((credential == null || !credential.containsKey(SERVICE_NAME_FIELD)) && (parameters == null || !parameters.containsKey(SERVICE_NAME_FIELD))) {
throw new IllegalArgumentException("Missing or invalid service name");
}
if ((credential == null || !credential.containsKey(REGION_FIELD)) && (parameters == null || !parameters.containsKey(REGION_FIELD))) {
throw new IllegalArgumentException("Missing region");
}
}

@Override
public Connector cloneConnector() {
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){
Expand All @@ -90,10 +74,6 @@ public String getAccessKey() {
return decryptedCredential.get(ACCESS_KEY_FIELD);
}

public String getRoleArn() {
return decryptedCredential.get(ROLE_ARN_FIELD);
}

public String getSecretKey() {
return decryptedCredential.get(SECRET_KEY_FIELD);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,4 @@ default void validateConnectorURL(List<String> urlRegexes) {
}

Map<String, String> getDecryptedHeaders();

Map<String, String> getCredential();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

package org.opensearch.ml.common.transport.controller;

import org.opensearch.action.support.nodes.BaseNodeRequest;
import java.io.IOException;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
public class MLDeployModelControllerNodeRequest extends BaseNodeRequest {
import org.opensearch.transport.TransportRequest;

public class MLDeployModelControllerNodeRequest extends TransportRequest {
@Getter
private MLDeployModelControllerNodesRequest deployModelControllerNodesRequest;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public MLUpdateModelCacheNodesRequest(String[] nodeIds, String modelId) {
this.modelId = modelId;
}

public MLInPlaceUpdateModelNodesRequest(DiscoveryNode[] nodeIds, String modelId) {
public MLUpdateModelCacheNodesRequest(DiscoveryNode[] nodeIds, String modelId) {
super(nodeIds);
this.modelId = modelId;
}
Expand Down
3 changes: 0 additions & 3 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ dependencies {
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'com.jayway.jsonpath:json-path:2.8.0'
implementation group: 'org.json', name: 'json', version: '20231013'
implementation "com.amazonaws:aws-java-sdk-core:1.12.48"
implementation "com.amazonaws:aws-java-sdk-sts:1.12.48"
implementation "com.amazonaws:aws-java-sdk-secretsmanager:1.12.48"
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,21 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.AwsConnector;
Expand All @@ -35,9 +29,6 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.credentials.aws.InternalAwsCredentials;
import org.opensearch.ml.engine.credentialscommunication.Credentials;
import org.opensearch.ml.engine.credentialscommunication.CredentialsRequest;
import org.opensearch.script.ScriptService;

import lombok.Getter;
Expand All @@ -63,9 +54,6 @@ public class AwsConnectorExecutor implements RemoteConnectorExecutor {
private ScriptService scriptService;
@Setter
@Getter
private ClusterService clusterService;
@Setter
@Getter
private TokenBucket modelRateLimiter;
@Setter
@Getter
Expand All @@ -83,95 +71,6 @@ public AwsConnectorExecutor(Connector connector) {
this(connector, new DefaultSdkHttpClientBuilder().build());
}

private Map<String, String> getCredentialsFromIAMRole(String roleArn, String clusterName) throws IOException {
Map<String, String> awsCredentials = new HashMap<>();
try {
CredentialsRequest credentialsRequest = new CredentialsRequest(roleArn, clusterName);
InternalAwsCredentials credentials = Credentials.getCredentials(credentialsRequest);
awsCredentials.put(ACCESS_KEY_FIELD, credentials.getAccessKey());
awsCredentials.put(SECRET_KEY_FIELD, credentials.getSecretKey());
awsCredentials.put(SESSION_TOKEN_FIELD, credentials.getSessionToken());
} catch (Exception ex) {
log.error("Exception occurred gaining credentials: " + ex);
throw ex;
}
return awsCredentials;
}

@Override
public void invokeRemoteModelInManagedService(
MLInput mlInput,
Map<String, String> parameters,
String payload,
List<ModelTensors> tensorOutputs
) {
try {
String clusterName = clusterService.getClusterName().toString();
String roleArn = "";
if (connector.getDecryptedCredential().get("roleArn") != null) {
roleArn = connector.getDecryptedCredential().get("roleArn");
}
Map<String, String> awsCredentials = getCredentialsFromIAMRole(roleArn, clusterName);
String endpoint = connector.getPredictEndpoint(parameters);
RequestBody requestBody = RequestBody.fromString(payload);

SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(POST)
.uri(URI.create(endpoint))
.contentStreamProvider(requestBody.contentStreamProvider());
Map<String, String> headers = connector.getDecryptedHeaders();
if (headers != null) {
for (String key : headers.keySet()) {
builder.putHeader(key, headers.get(key));
}
}
SdkHttpFullRequest request = builder.build();
HttpExecuteRequest executeRequest = HttpExecuteRequest
.builder()
.request(signRequestInManagedService(request, awsCredentials))
.contentStreamProvider(request.contentStreamProvider().orElse(null))
.build();

HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
return httpClient.prepareRequest(executeRequest).call();
});
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
body = response.responseBody().get();
}

StringBuilder responseBuilder = new StringBuilder();
if (body != null) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
responseBuilder.append(line);
}
}
} else {
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
}
String modelResponse = responseBuilder.toString();
if (statusCode < 200 || statusCode >= 300) {
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
throw exception;
} catch (Throwable e) {
log.error("Failed to execute predict in aws connector", e);
throw new MLException("Fail to execute predict in aws connector", e);
}

}

@Override
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
try {
Expand Down Expand Up @@ -243,13 +142,4 @@ private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {

return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region);
}

private SdkHttpFullRequest signRequestInManagedService(SdkHttpFullRequest request, Map<String, String> awsCredentials) {
String accessKey = awsCredentials.get(ACCESS_KEY_FIELD);
String secretKey = awsCredentials.get(SECRET_KEY_FIELD);
String sessionToken = awsCredentials.get(SESSION_TOKEN_FIELD);
String signingName = connector.getServiceName();
String region = connector.getRegion();
return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region);
}
}
Loading

0 comments on commit 5a2d455

Please sign in to comment.