diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 1b3d22fe66..94be28baa6 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -388,8 +388,8 @@ public List getRestHandlers( RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings); RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction(); RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); - RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(); - RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(); + RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); + RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings); return ImmutableList .of( @@ -507,7 +507,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES, - MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL + MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, + MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java index 3f8ad55300..fa18eb4f8b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java @@ -7,12 +7,15 @@ import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; import java.io.IOException; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; @@ -27,10 +30,19 @@ public class RestMLRegisterModelMetaAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_META_ACTION = "ml_register_model_meta_action"; + private volatile boolean isLocalFileUploadAllowed; + /** * Constructor + * @param clusterService cluster service + * @param settings settings */ - public RestMLRegisterModelMetaAction() {} + public RestMLRegisterModelMetaAction(ClusterService clusterService, Settings settings) { + isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it); + } @Override public String getName() { @@ -66,7 +78,11 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLRegisterModelMetaRequest getRequest(RestRequest request) throws IOException { boolean hasContent = request.hasContent(); - if (!hasContent) { + if (!isLocalFileUploadAllowed) { + throw new IllegalArgumentException( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models" + ); + } else if (!hasContent) { throw new IOException("Model meta request has empty body"); } XContentParser parser = request.contentParser(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java index 0e3441ada7..82b3754433 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java @@ -6,12 +6,15 @@ package org.opensearch.ml.rest; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; import java.io.IOException; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkRequest; @@ -24,11 +27,17 @@ public class RestMLUploadModelChunkAction extends BaseRestHandler { private static final String ML_UPLOAD_MODEL_CHUNK_ACTION = "ml_upload_model_chunk_action"; + private volatile boolean isLocalFileUploadAllowed; /** * Constructor */ - public RestMLUploadModelChunkAction() {} + public RestMLUploadModelChunkAction(ClusterService clusterService, Settings settings) { + isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it); + } @Override public String getName() { @@ -65,6 +74,11 @@ MLUploadModelChunkRequest getRequest(RestRequest request) throws IOException { final String modelId = request.param("model_id"); String chunk_number = request.param("chunk_number"); byte[] content = request.content().streamInput().readAllBytes(); + if (!isLocalFileUploadAllowed) { + throw new IllegalArgumentException( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models." + ); + } MLUploadModelChunkInput mlInput = new MLUploadModelChunkInput(modelId, Integer.parseInt(chunk_number), content); return new MLUploadModelChunkRequest(mlInput); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 1e0a354b13..161dc228b4 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -91,4 +91,12 @@ private MLCommonsSettings() {} // This setting is to enable/disable model url in model register API. public static final Setting ML_COMMONS_ALLOW_MODEL_URL = Setting .boolSetting("plugins.ml_commons.allow_registering_model_via_url", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD = Setting + .boolSetting( + "plugins.ml_commons.allow_registering_model_via_local_file", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 8ed7568724..a7167e727c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -5,15 +5,8 @@ package org.opensearch.ml.rest; -import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; -import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; -import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; -import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD; -import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; -import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; -import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; -import static org.opensearch.ml.common.MLTask.STATE_FIELD; -import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; +import static org.opensearch.commons.ConfigConstants.*; +import static org.opensearch.ml.common.MLTask.*; import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT; import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; @@ -23,14 +16,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; +import java.util.*; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; @@ -128,6 +114,17 @@ public void setupSettings() throws IOException { ); assertEquals(200, response.getStatusLine().getStatusCode()); + response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.allow_registering_model_via_local_file\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + String jsonEntity = "{\n" + " \"persistent\" : {\n" + " \"plugins.ml_commons.native_memory_threshold\" : 100 \n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java index e431d3a22c..065823ba17 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java @@ -8,6 +8,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; import java.util.HashMap; @@ -20,10 +22,13 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -50,12 +55,21 @@ public class RestMLRegisterModelMetaActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private ClusterService clusterService; + + private Settings settings; + @Rule - public ExpectedException expectedEx = ExpectedException.none(); + public ExpectedException expectedException = ExpectedException.none(); @Before public void setup() { - restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { @@ -72,7 +86,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(); + RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(clusterService, settings); assertNotNull(mlUploadModel); } @@ -112,12 +126,26 @@ public void testRegisterModelMetaRequest() throws Exception { assertEquals(Integer.valueOf(2), metaModelRequest.getTotalChunks()); } + public void testRegisterModelFileUploadNotAllowed() throws Exception { + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); + expectedException.expect(IllegalArgumentException.class); + expectedException + .expectMessage( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models" + ); + RestRequest request = getRestRequest(); + restMLRegisterModelMetaAction.handleRequest(request, channel, client); + } + public void testRegisterModelMeta_NoContent() throws Exception { RestRequest.Method method = RestRequest.Method.POST; Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(method).withParams(params).build(); - expectedEx.expect(IOException.class); - expectedEx.expectMessage("Model meta request has empty body"); + expectedException.expect(IOException.class); + expectedException.expectMessage("Model meta request has empty body"); restMLRegisterModelMetaAction.handleRequest(request, channel, client); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java index c3f227789c..eb411a0179 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java @@ -8,6 +8,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.util.HashMap; import java.util.List; @@ -15,12 +17,17 @@ import org.junit.Before; import org.junit.Ignore; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.model.MLModelGetResponse; @@ -43,10 +50,21 @@ public class RestMLUploadModelChunkActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private ClusterService clusterService; + + private Settings settings; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); @Before public void setup() { - restChunkUploadAction = new RestMLUploadModelChunkAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { @@ -63,7 +81,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction(); + RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction(clusterService, settings); assertNotNull(mlUploadChunk); } @@ -102,6 +120,20 @@ public void testUploadChunkRequest() throws Exception { assertEquals(Integer.valueOf(0), chunkRequest.getChunkNumber()); } + public void testRegisterModelFileUploadNotAllowed() throws Exception { + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings); + expectedException.expect(IllegalArgumentException.class); + expectedException + .expectMessage( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models" + ); + RestRequest request = getRestRequest(); + restChunkUploadAction.handleRequest(request, channel, client); + } + private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.POST; BytesArray content = new BytesArray("12345678");