Skip to content

Commit 3ba2baf

Browse files
Add global resource support (#4003)
* feat: Add global resource support for remote metadata with DynamoDB integration - Add DynamoDB client dependency for remote metadata storage - Implement tenant-aware encryption for global vs tenant-specific resources - Migrate GetConfigTransportAction to use SDK client instead of direct ES client - Add REMOTE_METADATA_GLOBAL_TENANT_ID setting support - Force AWS SDK version alignment for dependency consistency Signed-off-by: Zirui Song <zrsong@amazon.com> * Fix config get transport action test failures Signed-off-by: zane-neo <zaniu@amazon.com> * Add UTs to increase code coverage Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: Zirui Song <zrsong@amazon.com> Signed-off-by: zane-neo <zaniu@amazon.com> Co-authored-by: zane-neo <zaniu@amazon.com>
1 parent 247a88c commit 3ba2baf

File tree

14 files changed

+374
-116
lines changed

14 files changed

+374
-116
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ ml-algorithms/build/
1010
plugin/build/
1111
.DS_Store
1212
*/bin/
13+
**/*.factorypath

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package org.opensearch.ml.common.settings;
77

88
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_ENDPOINT_KEY;
9+
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL_KEY;
10+
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_GLOBAL_TENANT_ID_KEY;
911
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_REGION_KEY;
1012
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_SERVICE_NAME_KEY;
1113
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_TYPE_KEY;
@@ -372,4 +374,14 @@ private MLCommonsSettings() {}
372374
.boolSetting("plugins.ml_commons.agentic_memory_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
373375
public static final String ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE =
374376
"The Agentic Memory APIs are not enabled. To enable, please update the setting " + ML_COMMONS_AGENTIC_MEMORY_ENABLED.getKey();
377+
378+
public static final Setting<String> REMOTE_METADATA_GLOBAL_TENANT_ID = Setting
379+
.simpleString("plugins.ml-commons." + REMOTE_METADATA_GLOBAL_TENANT_ID_KEY, Setting.Property.NodeScope, Setting.Property.Final);
380+
381+
public static final Setting<String> REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL = Setting
382+
.simpleString(
383+
"plugins.ml-commons." + REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL_KEY,
384+
Setting.Property.NodeScope,
385+
Setting.Property.Final
386+
);
375387
}

ml-algorithms/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ dependencies {
5454
// Multi-tenant SDK Client
5555
implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
5656
implementation 'commons-beanutils:commons-beanutils:1.11.0'
57+
implementation "org.opensearch:opensearch-remote-metadata-sdk-ddb-client:${opensearch_build}"
5758

5859
def os = DefaultNativePlatform.currentOperatingSystem
5960
//arm/macos doesn't support GPU

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,15 @@ public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
150150
return predictable;
151151
}
152152

153+
public void deploy(MLModel mlModel, Map<String, Object> params, ActionListener<Predictable> listener) {
154+
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
155+
predictable.initModelAsync(mlModel, params, encryptor).thenAccept((b) -> listener.onResponse(predictable)).exceptionally(e -> {
156+
log.error("Failed to init init model", e);
157+
listener.onFailure(new RuntimeException(e));
158+
return null;
159+
});
160+
}
161+
153162
public MLExecutable deployExecute(MLModel mlModel, Map<String, Object> params) {
154163
MLExecutable executable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
155164
executable.initModel(mlModel, params);

ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine;
77

88
import java.util.Map;
9+
import java.util.concurrent.CompletionStage;
910

1011
import org.opensearch.core.action.ActionListener;
1112
import org.opensearch.ml.common.MLModel;
@@ -19,6 +20,8 @@
1920
*/
2021
public interface Predictable {
2122

23+
String METHOD_NOT_IMPLEMENTED_ERROR_MSG = "Method is not implemented";
24+
2225
/**
2326
* Predict with given input data and model.
2427
* Will reload model into memory with model content.
@@ -34,11 +37,11 @@ public interface Predictable {
3437
* @return predicted results
3538
*/
3639
default MLOutput predict(MLInput mlInput) {
37-
throw new IllegalStateException("Method is not implemented");
40+
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
3841
}
3942

4043
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
41-
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
44+
actionListener.onFailure(new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG));
4245
}
4346

4447
/**
@@ -47,7 +50,13 @@ default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> action
4750
* @param params other parameters
4851
* @param encryptor encryptor
4952
*/
50-
void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor);
53+
default void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
54+
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
55+
}
56+
57+
default CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object> params, Encryptor encryptor) {
58+
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
59+
}
5160

5261
/**
5362
* Close resources like deployed model.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@
66
package org.opensearch.ml.engine.algorithms.remote;
77

88
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9+
import static org.opensearch.ml.common.settings.MLCommonsSettings.REMOTE_METADATA_GLOBAL_TENANT_ID;
910

1011
import java.util.Map;
12+
import java.util.concurrent.CompletableFuture;
13+
import java.util.concurrent.CompletionStage;
1114
import java.util.concurrent.atomic.AtomicBoolean;
1215

1316
import org.opensearch.cluster.service.ClusterService;
17+
import org.opensearch.common.settings.Settings;
1418
import org.opensearch.common.util.TokenBucket;
1519
import org.opensearch.core.action.ActionListener;
1620
import org.opensearch.core.xcontent.NamedXContentRegistry;
1721
import org.opensearch.ml.common.FunctionName;
22+
import org.opensearch.ml.common.MLIndex;
1823
import org.opensearch.ml.common.MLModel;
1924
import org.opensearch.ml.common.connector.Connector;
2025
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
@@ -28,6 +33,7 @@
2833
import org.opensearch.ml.engine.Predictable;
2934
import org.opensearch.ml.engine.annotation.Function;
3035
import org.opensearch.ml.engine.encryptor.Encryptor;
36+
import org.opensearch.remote.metadata.client.SdkClient;
3137
import org.opensearch.script.ScriptService;
3238
import org.opensearch.transport.client.Client;
3339

@@ -47,6 +53,8 @@ public class RemoteModel implements Predictable {
4753
public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map";
4854
public static final String GUARDRAILS = "guardrails";
4955
public static final String CONNECTOR_PRIVATE_IP_ENABLED = "connectorPrivateIpEnabled";
56+
public static final String SDK_CLIENT = "sdk_client";
57+
public static final String SETTINGS = "settings";
5058

5159
private RemoteConnectorExecutor connectorExecutor;
5260

@@ -98,11 +106,14 @@ public boolean isModelReady() {
98106
}
99107

100108
@Override
101-
public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
102-
try {
109+
public CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object> params, Encryptor encryptor) {
110+
SdkClient sdkClient = (SdkClient) params.get(SDK_CLIENT);
111+
return sdkClient.isGlobalResource(MLIndex.MODEL.getIndexName(), model.getModelId()).thenCompose(isGlobalResource -> {
112+
String decryptTenantId = Boolean.TRUE.equals(isGlobalResource)
113+
? REMOTE_METADATA_GLOBAL_TENANT_ID.get((Settings) params.get(SETTINGS))
114+
: model.getTenantId();
103115
Connector connector = model.getConnector().cloneConnector();
104-
connector
105-
.decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, model.getTenantId()), model.getTenantId());
116+
connector.decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, decryptTenantId), decryptTenantId);
106117
// This situation can only happen for inline connector where we don't provide tenant id.
107118
if (connector.getTenantId() == null && model.getTenantId() != null) {
108119
connector.setTenantId(model.getTenantId());
@@ -116,13 +127,10 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
116127
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
117128
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
118129
this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
119-
} catch (RuntimeException e) {
120-
log.error("Failed to init remote model.", e);
121-
throw e;
122-
} catch (Throwable e) {
130+
return CompletableFuture.completedStage(true);
131+
}).exceptionally(e -> {
123132
log.error("Failed to init remote model.", e);
124133
throw new MLException(e);
125-
}
134+
});
126135
}
127-
128136
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertNotEquals;
1010
import static org.junit.Assert.assertNotNull;
11+
import static org.junit.Assert.assertTrue;
1112
import static org.junit.Assert.fail;
13+
import static org.mockito.ArgumentMatchers.any;
14+
import static org.mockito.Mockito.mock;
15+
import static org.mockito.Mockito.verify;
16+
import static org.mockito.Mockito.when;
17+
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SDK_CLIENT;
18+
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SETTINGS;
1219
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
1320
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
1421
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
@@ -19,14 +26,17 @@
1926
import java.util.Arrays;
2027
import java.util.Collections;
2128
import java.util.List;
29+
import java.util.Locale;
2230
import java.util.Map;
2331
import java.util.UUID;
32+
import java.util.concurrent.CompletableFuture;
2433

2534
import org.junit.Assert;
2635
import org.junit.Before;
2736
import org.junit.Rule;
2837
import org.junit.Test;
2938
import org.junit.rules.ExpectedException;
39+
import org.mockito.ArgumentCaptor;
3040
import org.mockito.MockedStatic;
3141
import org.opensearch.common.settings.Settings;
3242
import org.opensearch.common.xcontent.XContentType;
@@ -37,6 +47,7 @@
3747
import org.opensearch.core.xcontent.XContentParser;
3848
import org.opensearch.ml.common.FunctionName;
3949
import org.opensearch.ml.common.MLModel;
50+
import org.opensearch.ml.common.connector.AwsConnector;
4051
import org.opensearch.ml.common.connector.HttpConnector;
4152
import org.opensearch.ml.common.dataframe.ColumnMeta;
4253
import org.opensearch.ml.common.dataframe.DataFrame;
@@ -57,8 +68,11 @@
5768
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
5869
import org.opensearch.ml.engine.encryptor.Encryptor;
5970
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
71+
import org.opensearch.remote.metadata.client.SdkClient;
6072
import org.opensearch.search.SearchModule;
6173

74+
import software.amazon.awssdk.utils.ImmutableMap;
75+
6276
// TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic
6377
public class MLEngineTest extends MLStaticMockBase {
6478
@Rule
@@ -523,4 +537,74 @@ public void testGetConnectorCredentialWithoutRegion() throws IOException {
523537
assertEquals("test_key_value", decryptedCredential.get("key"));
524538
assertEquals(null, decryptedCredential.get("region"));
525539
}
540+
541+
@Test
542+
public void testDeploy_withPredictableActionListener_successful() throws IOException {
543+
String encryptedAccessKey = mlEngine.encrypt("access-key", null);
544+
String encryptedSecretKey = mlEngine.encrypt("secret-key", null);
545+
String testConnector = String.format(Locale.ROOT, """
546+
{
547+
"name": "sagemaker: t2ppl",
548+
"description": "t2ppl model",
549+
"version": 1,
550+
"protocol": "aws_sigv4",
551+
"credential": {
552+
"access_key": "%s",
553+
"secret_key": "%s"
554+
},
555+
"parameters": {
556+
"region": "us-east-1",
557+
"service_name": "sagemaker",
558+
"input_type": "search_document"
559+
},
560+
"actions": [
561+
{
562+
"action_type": "predict",
563+
"method": "POST",
564+
"headers": {
565+
"content-type": "application/json",
566+
"x-amz-content-sha256": "required"
567+
},
568+
"url": "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/my-endpoint/invocations",
569+
"request_body": "{\\"prompt\\":\\"${parameters.prompt}\\"}"
570+
}
571+
]
572+
}
573+
""", encryptedAccessKey, encryptedSecretKey);
574+
575+
XContentParser parser = XContentType.JSON
576+
.xContent()
577+
.createParser(
578+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
579+
null,
580+
testConnector
581+
);
582+
parser.nextToken();
583+
584+
MLModel model = mock(MLModel.class);
585+
AwsConnector connector = new AwsConnector("aws_sigv4", parser);
586+
when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE);
587+
when(model.getConnector()).thenReturn(connector);
588+
ActionListener<Predictable> actionListener = mock(ActionListener.class);
589+
SdkClient sdkClient = mock(SdkClient.class);
590+
when(sdkClient.isGlobalResource(any(), any())).thenReturn(CompletableFuture.completedFuture(false));
591+
Map<String, Object> params = ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, Settings.EMPTY);
592+
mlEngine.deploy(model, params, actionListener);
593+
verify(actionListener).onResponse(any(Predictable.class));
594+
}
595+
596+
@Test
597+
public void testDeploy_withPredictableActionListener_exceptional() {
598+
MLModel model = mock(MLModel.class);
599+
when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE);
600+
when(model.getConnector()).thenThrow(new RuntimeException("Runtime error"));
601+
ActionListener<Predictable> actionListener = mock(ActionListener.class);
602+
SdkClient sdkClient = mock(SdkClient.class);
603+
when(sdkClient.isGlobalResource(any(), any())).thenReturn(CompletableFuture.completedFuture(false));
604+
Map<String, Object> params = ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, Settings.EMPTY);
605+
mlEngine.deploy(model, params, actionListener);
606+
ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
607+
verify(actionListener).onFailure(argumentCaptor.capture());
608+
assertTrue(argumentCaptor.getValue().getMessage().contains("Runtime error"));
609+
}
526610
}

0 commit comments

Comments
 (0)