Skip to content

Commit

Permalink
PenTest fixes: error codes and update model group fix (opensearch-pro…
Browse files Browse the repository at this point in the history
…ject#1074)

* PenTest fixes: error codes and update model group fix

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>

* fix get model assertion error

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>

---------

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna authored and zane-neo committed Sep 1, 2023
1 parent 34b07c5 commit b51b23d
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 81 deletions.
6 changes: 3 additions & 3 deletions docs/model_access_control.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ Updating a model group request is very similar to register model group request.
### Path and HTTP method

```
PUT /_plugins/_ml/model_groups/<model_group_id>/_update
PUT /_plugins/_ml/model_groups/<model_group_id>
```

A user can make updates to a model group to which he/she has access which is determined by the access mode of the model group.
Expand All @@ -196,7 +196,7 @@ For example,
Sample request allowed by admin/owner

```
PUT /_plugins/_ml/model_groups/<model_group_id>/_update
PUT /_plugins/_ml/model_groups/<model_group_id>
{
"name": "model_group_test",
"description": "This is an example description",
Expand All @@ -215,7 +215,7 @@ PUT /_plugins/_ml/model_groups/<model_group_id>/_update
Sample update request allowed by any other user with access to model group.

```
PUT /_plugins/_ml/model_groups/<model_group_id>/_update
PUT /_plugins/_ml/model_groups/<model_group_id>
{
"name": "model_group_test",
"description": "This is an example description"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
Expand Down Expand Up @@ -41,6 +42,7 @@
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.RestStatus;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -101,7 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
}
} else {
listener.onFailure(new MLResourceNotFoundException("Failed to find model group"));
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down Expand Up @@ -197,18 +199,18 @@ private void validateRequestForAccessControl(MLUpdateModelGroupInput input, User
if (hasAccessControlChange(input)) {
if (!modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) && !modelAccessControlHelper.isAdmin(user)) {
throw new IllegalArgumentException("Only owner or admin can update access control data.");
} else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user)
&& !modelAccessControlHelper.isAdmin(user)
&& !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) {
throw new IllegalArgumentException(
"You don’t have the specified backend role to update access control data. For more information, contact your administrator."
);
}
}
if (!modelAccessControlHelper.isAdmin(user)
&& !modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user)
&& !modelAccessControlHelper.isUserHasBackendRole(user, mlModelGroup)) {
throw new IllegalArgumentException("You don't have permission to update this model group.");
} else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user)
&& !modelAccessControlHelper.isAdmin(user)
&& !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) {
throw new IllegalArgumentException(
"You don’t have the specified backend role to update access control data. For more information, contact your administrator."
);
}
AccessMode accessMode = input.getModelAccessMode();
if ((AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
Expand All @@ -22,6 +23,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -110,7 +112,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new IllegalArgumentException("Failed to find model with the provided model id: " + modelId));
actionListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model with the provided model id: " + modelId,
RestStatus.NOT_FOUND
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -68,7 +70,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
actionListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.function.Supplier;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
Expand Down Expand Up @@ -891,7 +892,7 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio
listener.onFailure(e);
}
} else {
listener.onFailure(new MLResourceNotFoundException("Fail to find model"));
listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
}
}, e -> { listener.onFailure(e); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
}, e -> {
log.error("Failed to get ML model", e);
try {
channel.sendResponse(new BytesRestResponse(channel, RestStatus.BAD_REQUEST, e));
channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e));
} catch (IOException ex) {
log.error("Failed to send error response", ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ public String getName() {
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.PUT,
String.format(Locale.ROOT, "%s/model_groups/{%s}/_update", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID)
)
new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/model_groups/{%s}", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID))
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void testDeleteModel_Success() throws IOException {
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -172,16 +172,7 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException {
return null;
}).when(client).execute(any(), any(), any());

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -213,17 +204,8 @@ public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOExcepti
return null;
}).when(client).search(any(), isA(ActionListener.class));

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelGroupId("modelGroupID")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -296,17 +278,8 @@ public void test_Failure_FailedToSearchLastModel() throws IOException {
return null;
}).when(client).search(any(), isA(ActionListener.class));

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelGroupId("modelGroupID")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -320,7 +293,7 @@ public void test_Failure_FailedToSearchLastModel() throws IOException {
}

public void test_UserHasNoAccessException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -340,7 +313,7 @@ public void test_UserHasNoAccessException() throws IOException {
}

public void testDeleteModel_CheckModelState() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING);
GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -383,7 +356,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException {
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -397,7 +370,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException {
}

public void test_ValidationFailedException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -441,7 +414,7 @@ public void testDeleteModelChunks_Success() {
}

public void testDeleteModel_RuntimeException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -535,8 +508,8 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() {
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build();
public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
Expand Down
Loading

0 comments on commit b51b23d

Please sign in to comment.