Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PenTest fixes: error codes and update model group fix #1074

Merged
merged 2 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/model_access_control.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ Updating a model group request is very similar to register model group request.
### Path and HTTP method

```
PUT /_plugins/_ml/model_groups/<model_group_id>/_update
PUT /_plugins/_ml/model_groups/<model_group_id>
```

A user can make updates to a model group to which he/she has access which is determined by the access mode of the model group.
Expand All @@ -196,7 +196,7 @@ For example,
Sample request allowed by admin/owner

```
PUT /_plugins/_ml/model_groups/<model_group_id>/_update
PUT /_plugins/_ml/model_groups/<model_group_id>
{
"name": "model_group_test",
"description": "This is an example description",
Expand All @@ -215,7 +215,7 @@ PUT /_plugins/_ml/model_groups/<model_group_id>/_update
Sample update request allowed by any other user with access to model group.

```
PUT /_plugins/_ml/model_groups/<model_group_id>/_update
PUT /_plugins/_ml/model_groups/<model_group_id>
{
"name": "model_group_test",
"description": "This is an example description"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -41,6 +42,7 @@
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.RestStatus;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -101,7 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
}
} else {
listener.onFailure(new MLResourceNotFoundException("Failed to find model group"));
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down Expand Up @@ -197,18 +199,18 @@ private void validateRequestForAccessControl(MLUpdateModelGroupInput input, User
if (hasAccessControlChange(input)) {
if (!modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) && !modelAccessControlHelper.isAdmin(user)) {
throw new IllegalArgumentException("Only owner or admin can update access control data.");
} else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user)
&& !modelAccessControlHelper.isAdmin(user)
&& !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) {
throw new IllegalArgumentException(
"You don’t have the specified backend role to update access control data. For more information, contact your administrator."
);
}
}
if (!modelAccessControlHelper.isAdmin(user)
&& !modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user)
&& !modelAccessControlHelper.isUserHasBackendRole(user, mlModelGroup)) {
throw new IllegalArgumentException("You don't have permission to update this model group.");
} else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user)
&& !modelAccessControlHelper.isAdmin(user)
&& !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) {
throw new IllegalArgumentException(
"You don’t have the specified backend role to update access control data. For more information, contact your administrator."
);
}
AccessMode accessMode = input.getModelAccessMode();
if ((AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
Expand All @@ -34,6 +35,7 @@
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -110,7 +112,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new IllegalArgumentException("Failed to find model with the provided model id: " + modelId));
actionListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model with the provided model id: " + modelId,
RestStatus.NOT_FOUND
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
Expand All @@ -25,6 +26,7 @@
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -68,7 +70,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
actionListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.function.Supplier;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -891,7 +892,7 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio
listener.onFailure(e);
}
} else {
listener.onFailure(new MLResourceNotFoundException("Fail to find model"));
listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
}
}, e -> { listener.onFailure(e); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
}, e -> {
log.error("Failed to get ML model", e);
try {
channel.sendResponse(new BytesRestResponse(channel, RestStatus.BAD_REQUEST, e));
channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e));
} catch (IOException ex) {
log.error("Failed to send error response", ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ public String getName() {
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.PUT,
String.format(Locale.ROOT, "%s/model_groups/{%s}/_update", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID)
)
new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/model_groups/{%s}", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID))
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void testDeleteModel_Success() throws IOException {
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -172,16 +172,7 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException {
return null;
}).when(client).execute(any(), any(), any());

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -213,17 +204,8 @@ public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOExcepti
return null;
}).when(client).search(any(), isA(ActionListener.class));

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelGroupId("modelGroupID")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -296,17 +278,8 @@ public void test_Failure_FailedToSearchLastModel() throws IOException {
return null;
}).when(client).search(any(), isA(ActionListener.class));

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelGroupId("modelGroupID")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -320,7 +293,7 @@ public void test_Failure_FailedToSearchLastModel() throws IOException {
}

public void test_UserHasNoAccessException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -340,7 +313,7 @@ public void test_UserHasNoAccessException() throws IOException {
}

public void testDeleteModel_CheckModelState() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING);
GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -383,7 +356,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException {
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand All @@ -397,7 +370,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException {
}

public void test_ValidationFailedException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -441,7 +414,7 @@ public void testDeleteModelChunks_Success() {
}

public void testDeleteModel_RuntimeException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED);
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
Expand Down Expand Up @@ -535,8 +508,8 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() {
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build();
public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
Expand Down
Loading