diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 4449ee6996..191d5ad9c2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -7,6 +7,11 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -23,6 +28,11 @@ import org.opensearch.script.ScriptService; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; @Log4j2 @Function(FunctionName.REMOTE) @@ -77,11 +87,42 @@ public boolean isModelReady() { public void initModel(MLModel model, Map params, Encryptor encryptor) { try { Connector connector = model.getConnector().cloneConnector(); - connector.decrypt((credential) -> encryptor.decrypt(credential)); + + ClusterService clusterService = (ClusterService) params.get(CLUSTER_SERVICE); + Client client = (Client) params.get(CLIENT); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference exceptionRef = new AtomicReference<>(); + if (encryptor.getMasterKey() == null) { + if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, new LatchedActionListener(ActionListener.< GetResponse >wrap(r-> { + if (r.isExists()) { + String masterKey = (String)r.getSourceAsMap().get(MASTER_KEY); + encryptor.setMasterKey(masterKey); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + }, e-> { + log.error("Failed to get ML encryption master key", e); + exceptionRef.set(e); + }), latch)); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + } + + if (exceptionRef.get() != null) { + throw exceptionRef.get(); + } + if (encryptor.getMasterKey() != null) { + connector.decrypt((credential) -> encryptor.decrypt(credential)); + } else { + throw new MLException("ML encryptor not initialized"); + } this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); - this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); - this.connectorExecutor.setClient((Client) params.get(CLIENT)); + this.connectorExecutor.setClusterService(clusterService); + this.connectorExecutor.setClient(client); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); } catch (RuntimeException e) { log.error("Failed to init remote model", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java index 7bbe58cec5..2316869ffd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java @@ -32,6 +32,7 @@ public interface Encryptor { * @param masterKey masterKey to be set. */ void setMasterKey(String masterKey); + String getMasterKey(); String generateMasterKey(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 16abde0a24..6179e0e297 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -29,6 +29,11 @@ public void setMasterKey(String masterKey) { this.masterKey = masterKey; } + @Override + public String getMasterKey() { + return masterKey; + } + @Override public String encrypt(String plainText) { checkMasterKey(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index b39ae6bc9e..b1fabaa0c8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -13,6 +13,16 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -22,19 +32,38 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; public class RemoteModelTest { @Mock MLInput mlInput; + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + @Mock MLModel mlModel; @@ -44,12 +73,46 @@ public class RemoteModelTest { RemoteModel remoteModel; Encryptor encryptor; + String masterKey; + + Map params; + private static final AtomicInteger portGenerator = new AtomicInteger(); + @Before public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); encryptor = spy(new EncryptorImpl()); - encryptor.setMasterKey("0000000000000001"); + masterKey = "0000000000000001"; + encryptor.setMasterKey(masterKey); + params = new HashMap<>(); + params.put(CLIENT, client); + params.put(CLUSTER_SERVICE, clusterService); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + + when(clusterService.state()).thenReturn(clusterState); + + Metadata metadata = new Metadata.Builder() + .indices(ImmutableMap + .builder() + .put(ML_CONFIG_INDEX, IndexMetadata.builder(ML_CONFIG_INDEX) + .settings(Settings.builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id)) + .build()) + .build()).build(); + when(clusterState.metadata()).thenReturn(metadata); } @Test @@ -112,6 +175,80 @@ public void initModel_WithHeader() { Assert.assertNull(remoteModel.getConnectorExecutor()); } + @Test + public void initModel_WithHeader_NullMasterKey_MasterKeyExistInIndex() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + Map decryptedHeaders = connector.getDecryptedHeaders(); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + Assert.assertNotNull(executor); + Assert.assertNull(decryptedHeaders); + Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); + Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); + Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); + + remoteModel.close(); + Assert.assertNull(remoteModel.getConnectorExecutor()); + } + + @Test + public void initModel_WithHeader_NullMasterKey_MasterKeyNotExistInIndex() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + } + + @Test + public void initModel_WithHeader_GetMasterKey_Exception() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("test error"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + } + + @Test + public void initModel_WithHeader_IndexNotFound() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build(); + when(clusterState.metadata()).thenReturn(metadata); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + } + private Connector createConnector(Map headers) { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT)