Skip to content

Commit

Permalink
do not allow non super admin users to undeploy hidden models (opensea…
Browse files Browse the repository at this point in the history
…rch-project#1981) (opensearch-project#1983)

* do not allow non super admin users to undeploy hidden models

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit 4de949c)

Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and rbhavna authored Feb 2, 2024
1 parent bfebdd8 commit f63c4df
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
1 change: 1 addition & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.profile.MLPredictRequestStats',
'org.opensearch.ml.action.deploy.TransportDeployModelAction',
'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction',
'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction',
'org.opensearch.ml.action.prediction.TransportPredictionTaskAction',
'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1',
'org.opensearch.ml.action.tasks.GetTaskTransportAction',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@

package org.opensearch.ml.action.undeploy;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
Expand All @@ -18,6 +26,10 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
Expand All @@ -32,6 +44,8 @@
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -93,27 +107,49 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUnde
String[] modelIds = undeployModelsRequest.getModelIds();
String[] targetNodeIds = undeployModelsRequest.getNodeIds();

if (modelAccessControlHelper.isModelAccessControlEnabled()) {
// Only allow user undeploy one model if model access control enabled.
if (modelIds == null || modelIds.length != 1) {
throw new IllegalArgumentException("only support undeploy one model");
}

if (modelIds.length == 1) {
String modelId = modelIds[0];
validateAccess(modelId, ActionListener.wrap(hasPermissionToUndeploy -> {
if (hasPermissionToUndeploy) {
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
listener.onResponse(new MLUndeployModelsResponse(r));
}, e -> { listener.onFailure(e); }));
undeployModels(targetNodeIds, modelIds, listener);
} else {
listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId));
}
}, listener::onFailure));
return;
} else {
// Only allow user to undeploy one model if model access control enabled.
// With multiple models, it is difficult to check to which models user has access to.
if (modelAccessControlHelper.isModelAccessControlEnabled()) {
throw new IllegalArgumentException("only support undeploy one model");
} else {
searchHiddenModels(modelIds, ActionListener.wrap(hiddenModels -> {
if (hiddenModels != null
&& hiddenModels.getHits().getTotalHits() != null
&& hiddenModels.getHits().getTotalHits().value != 0
&& !isSuperAdminUserWrapper(clusterService, client)) {
List<String> hiddenModelIds = Arrays
.stream(hiddenModels.getHits().getHits())
.map(SearchHit::getId)
.collect(Collectors.toList());

String[] modelsIDsToUndeploy = Arrays
.stream(modelIds)
.filter(modelId -> !hiddenModelIds.contains(modelId))
.toArray(String[]::new);

undeployModels(targetNodeIds, modelsIDsToUndeploy, listener);
} else {
undeployModels(targetNodeIds, modelIds, listener);
}
}, e -> {
log.error("Failed to search model index", e);
listener.onFailure(e);
}));
}
}
}

private void undeployModels(String[] targetNodeIds, String[] modelIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
Expand Down Expand Up @@ -153,6 +189,42 @@ private void validateAccess(String modelId, ActionListener<Boolean> listener) {
}
}

public void searchHiddenModels(String[] modelIds, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
// Create a TermsQueryBuilder for MODEL_ID_FIELD using the modelIds
TermsQueryBuilder termsQuery = QueryBuilders.termsQuery("_id", modelIds);

// Create a TermQueryBuilder for IS_HIDDEN_FIELD with value true
TermQueryBuilder isHiddenQuery = QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, true);

// Create an existsQuery to exclude model chunks
// Combine the queries using a bool query with must and mustNot clause
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder
.query(
QueryBuilders
.boolQuery()
.must(termsQuery)
.must(isHiddenQuery)
.mustNot(QueryBuilders.existsQuery(MLModel.CHUNK_NUMBER_FIELD))
);

SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);

client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(models -> { listener.onResponse(models); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(null);
} else {
log.error("Failed to search model index", e);
listener.onFailure(e);
}
}), () -> context.restore()));
} catch (Exception e) {
log.error("Failed to search model index", e);
listener.onFailure(e);
}
}

@VisibleForTesting
boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
return RestActionUtils.isSuperAdminUser(clusterService, client);
Expand Down

0 comments on commit f63c4df

Please sign in to comment.