Skip to content

Commit

Permalink
Add a setting to enable/disable local upload while registering model (#…
Browse files Browse the repository at this point in the history
…873)

Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna authored May 1, 2023
1 parent 94f6d22 commit 368c362
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ public List<RestHandler> getRestHandlers(
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings);
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction();
RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction();
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings);

return ImmutableList
.of(
Expand Down Expand Up @@ -507,7 +507,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES,
MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL
MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL,
MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction;
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput;
Expand All @@ -27,10 +30,19 @@
public class RestMLRegisterModelMetaAction extends BaseRestHandler {
private static final String ML_REGISTER_MODEL_META_ACTION = "ml_register_model_meta_action";

private volatile boolean isLocalFileUploadAllowed;

/**
* Constructor
* @param clusterService cluster service
* @param settings settings
*/
public RestMLRegisterModelMetaAction() {}
public RestMLRegisterModelMetaAction(ClusterService clusterService, Settings settings) {
isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it);
}

@Override
public String getName() {
Expand Down Expand Up @@ -66,7 +78,11 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
@VisibleForTesting
MLRegisterModelMetaRequest getRequest(RestRequest request) throws IOException {
boolean hasContent = request.hasContent();
if (!hasContent) {
if (!isLocalFileUploadAllowed) {
throw new IllegalArgumentException(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models"
);
} else if (!hasContent) {
throw new IOException("Model meta request has empty body");
}
XContentParser parser = request.contentParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction;
import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput;
import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkRequest;
Expand All @@ -24,11 +27,17 @@

public class RestMLUploadModelChunkAction extends BaseRestHandler {
private static final String ML_UPLOAD_MODEL_CHUNK_ACTION = "ml_upload_model_chunk_action";
private volatile boolean isLocalFileUploadAllowed;

/**
* Constructor
*/
public RestMLUploadModelChunkAction() {}
public RestMLUploadModelChunkAction(ClusterService clusterService, Settings settings) {
isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it);
}

@Override
public String getName() {
Expand Down Expand Up @@ -65,6 +74,11 @@ MLUploadModelChunkRequest getRequest(RestRequest request) throws IOException {
final String modelId = request.param("model_id");
String chunk_number = request.param("chunk_number");
byte[] content = request.content().streamInput().readAllBytes();
if (!isLocalFileUploadAllowed) {
throw new IllegalArgumentException(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models."
);
}
MLUploadModelChunkInput mlInput = new MLUploadModelChunkInput(modelId, Integer.parseInt(chunk_number), content);
return new MLUploadModelChunkRequest(mlInput);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,12 @@ private MLCommonsSettings() {}
// This setting is to enable/disable model url in model register API.
public static final Setting<Boolean> ML_COMMONS_ALLOW_MODEL_URL = Setting
.boolSetting("plugins.ml_commons.allow_registering_model_via_url", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD = Setting
.boolSetting(
"plugins.ml_commons.allow_registering_model_via_local_file",
false,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,8 @@

package org.opensearch.ml.rest;

import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH;
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
import static org.opensearch.commons.ConfigConstants.*;
import static org.opensearch.ml.common.MLTask.*;
import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT;
import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT;
import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL;
Expand All @@ -23,14 +16,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
Expand Down Expand Up @@ -128,6 +114,17 @@ public void setupSettings() throws IOException {
);
assertEquals(200, response.getStatusLine().getStatusCode());

response = TestHelper
.makeRequest(
client(),
"PUT",
"_cluster/settings",
null,
"{\"persistent\":{\"plugins.ml_commons.allow_registering_model_via_local_file\":true}}",
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);
assertEquals(200, response.getStatusLine().getStatusCode());

String jsonEntity = "{\n"
+ " \"persistent\" : {\n"
+ " \"plugins.ml_commons.native_memory_threshold\" : 100 \n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.util.HashMap;
Expand All @@ -20,10 +22,13 @@
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.bytes.BytesArray;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -50,12 +55,21 @@ public class RestMLRegisterModelMetaActionTests extends OpenSearchTestCase {
@Mock
RestChannel channel;

@Mock
private ClusterService clusterService;

private Settings settings;

@Rule
public ExpectedException expectedEx = ExpectedException.none();
public ExpectedException expectedException = ExpectedException.none();

@Before
public void setup() {
restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction();
MockitoAnnotations.openMocks(this);
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));
doAnswer(invocation -> {
Expand All @@ -72,7 +86,7 @@ public void tearDown() throws Exception {
}

public void testConstructor() {
RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction();
RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(clusterService, settings);
assertNotNull(mlUploadModel);
}

Expand Down Expand Up @@ -112,12 +126,26 @@ public void testRegisterModelMetaRequest() throws Exception {
assertEquals(Integer.valueOf(2), metaModelRequest.getTotalChunks());
}

public void testRegisterModelFileUploadNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
expectedException.expect(IllegalArgumentException.class);
expectedException
.expectMessage(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models"
);
RestRequest request = getRestRequest();
restMLRegisterModelMetaAction.handleRequest(request, channel, client);
}

public void testRegisterModelMeta_NoContent() throws Exception {
RestRequest.Method method = RestRequest.Method.POST;
Map<String, String> params = new HashMap<>();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(method).withParams(params).build();
expectedEx.expect(IOException.class);
expectedEx.expectMessage("Model meta request has empty body");
expectedException.expect(IOException.class);
expectedException.expectMessage("Model meta request has empty body");
restMLRegisterModelMetaAction.handleRequest(request, channel, client);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,26 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.bytes.BytesArray;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
Expand All @@ -43,10 +50,21 @@ public class RestMLUploadModelChunkActionTests extends OpenSearchTestCase {

@Mock
RestChannel channel;
@Mock
private ClusterService clusterService;

private Settings settings;

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Before
public void setup() {
restChunkUploadAction = new RestMLUploadModelChunkAction();
MockitoAnnotations.openMocks(this);
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings);
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));
doAnswer(invocation -> {
Expand All @@ -63,7 +81,7 @@ public void tearDown() throws Exception {
}

public void testConstructor() {
RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction();
RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction(clusterService, settings);
assertNotNull(mlUploadChunk);
}

Expand Down Expand Up @@ -102,6 +120,20 @@ public void testUploadChunkRequest() throws Exception {
assertEquals(Integer.valueOf(0), chunkRequest.getChunkNumber());
}

public void testRegisterModelFileUploadNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings);
expectedException.expect(IllegalArgumentException.class);
expectedException
.expectMessage(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models"
);
RestRequest request = getRestRequest();
restChunkUploadAction.handleRequest(request, channel, client);
}

private RestRequest getRestRequest() {
RestRequest.Method method = RestRequest.Method.POST;
BytesArray content = new BytesArray("12345678");
Expand Down

0 comments on commit 368c362

Please sign in to comment.