Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

do not allow non super admin users to undeploy hidden models #1981

Merged
merged 2 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.profile.MLPredictRequestStats',
'org.opensearch.ml.action.deploy.TransportDeployModelAction',
'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction',
'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction',
'org.opensearch.ml.action.prediction.TransportPredictionTaskAction',
'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1',
'org.opensearch.ml.action.tasks.GetTaskTransportAction',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@

package org.opensearch.ml.action.undeploy;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
Expand All @@ -18,6 +26,10 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
Expand All @@ -32,6 +44,8 @@
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -93,27 +107,49 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUnde
String[] modelIds = undeployModelsRequest.getModelIds();
String[] targetNodeIds = undeployModelsRequest.getNodeIds();

if (modelAccessControlHelper.isModelAccessControlEnabled()) {
// Only allow user undeploy one model if model access control enabled.
if (modelIds == null || modelIds.length != 1) {
throw new IllegalArgumentException("only support undeploy one model");
}

if (modelIds.length == 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when there is only one model, why there is no need to check whether it's hidden model or not?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do check it. All validations are done here when there is only one model

String modelId = modelIds[0];
validateAccess(modelId, ActionListener.wrap(hasPermissionToUndeploy -> {
if (hasPermissionToUndeploy) {
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
listener.onResponse(new MLUndeployModelsResponse(r));
}, e -> { listener.onFailure(e); }));
undeployModels(targetNodeIds, modelIds, listener);
} else {
listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId));
}
}, listener::onFailure));
return;
} else {
// Only allow user to undeploy one model if model access control enabled.
// With multiple models, it is difficult to check to which models user has access to.
if (modelAccessControlHelper.isModelAccessControlEnabled()) {
throw new IllegalArgumentException("only support undeploy one model");
} else {
searchHiddenModels(modelIds, ActionListener.wrap(hiddenModels -> {
if (hiddenModels != null
&& hiddenModels.getHits().getTotalHits() != null
&& hiddenModels.getHits().getTotalHits().value != 0
&& !isSuperAdminUserWrapper(clusterService, client)) {
List<String> hiddenModelIds = Arrays
.stream(hiddenModels.getHits().getHits())
.map(SearchHit::getId)
.collect(Collectors.toList());

String[] modelsIDsToUndeploy = Arrays
.stream(modelIds)
.filter(modelId -> !hiddenModelIds.contains(modelId))
.toArray(String[]::new);

undeployModels(targetNodeIds, modelsIDsToUndeploy, listener);
} else {
undeployModels(targetNodeIds, modelIds, listener);
}
}, e -> {
log.error("Failed to search model index", e);
listener.onFailure(e);
}));
}
}
}

private void undeployModels(String[] targetNodeIds, String[] modelIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
Expand Down Expand Up @@ -153,6 +189,42 @@ private void validateAccess(String modelId, ActionListener<Boolean> listener) {
}
}

public void searchHiddenModels(String[] modelIds, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
// Create a TermsQueryBuilder for MODEL_ID_FIELD using the modelIds
TermsQueryBuilder termsQuery = QueryBuilders.termsQuery("_id", modelIds);

// Create a TermQueryBuilder for IS_HIDDEN_FIELD with value true
TermQueryBuilder isHiddenQuery = QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, true);

// Create an existsQuery to exclude model chunks
// Combine the queries using a bool query with must and mustNot clause
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder
.query(
QueryBuilders
.boolQuery()
.must(termsQuery)
.must(isHiddenQuery)
.mustNot(QueryBuilders.existsQuery(MLModel.CHUNK_NUMBER_FIELD))
);

SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);

client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(models -> { listener.onResponse(models); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(null);
} else {
log.error("Failed to search model index", e);
listener.onFailure(e);
}
}), () -> context.restore()));
} catch (Exception e) {
log.error("Failed to search model index", e);
listener.onFailure(e);
}
}

@VisibleForTesting
boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
return RestActionUtils.isSuperAdminUser(clusterService, client);
Expand Down
Loading