Skip to content

Commit

Permalink
enable auto redeploy for hidden model (#2102) (#2136)
Browse files Browse the repository at this point in the history
* enable auto redeploy for hidden model

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit 9567ca5)

Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and rbhavna authored Feb 20, 2024
1 parent 235fb9c commit a31215d
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,32 @@ public class MLDeployModelRequest extends MLTaskRequest {
private String modelId;
private String[] modelNodeIds;
boolean async;
// This is to identify if the deploy request is initiated by user or not. During auto redeploy also, we perform deploy operation.
// This field is mainly to distinguish between these two situations.
private final boolean isUserInitiatedDeployRequest;

@Builder
public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask) {
public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask, boolean isUserInitiatedDeployRequest) {
super(dispatchTask);
this.modelId = modelId;
this.modelNodeIds = modelNodeIds;
this.async = async;
this.isUserInitiatedDeployRequest = isUserInitiatedDeployRequest;
}

// In this constructor, isUserInitiatedDeployRequest to always set to true. So, it can be used only when
// deploy request is coming directly from the user. DO NOT use this when the
// deploy call is from the code or system initiated.
public MLDeployModelRequest(String modelId, boolean async) {
this(modelId, null, async, true);
this(modelId, null, async, true, true);
}

public MLDeployModelRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.modelNodeIds = in.readOptionalStringArray();
this.async = in.readBoolean();
this.isUserInitiatedDeployRequest = in.readBoolean();
}

@Override
Expand All @@ -74,6 +82,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeOptionalStringArray(modelNodeIds);
out.writeBoolean(async);
out.writeBoolean(isUserInitiatedDeployRequest);
}

public static MLDeployModelRequest parse(XContentParser parser, String modelId) throws IOException {
Expand All @@ -96,7 +105,7 @@ public static MLDeployModelRequest parse(XContentParser parser, String modelId)
}
}
String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]);
return new MLDeployModelRequest(modelId, nodeIds, false, true);
return new MLDeployModelRequest(modelId, nodeIds, false, true, true);
}

public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ public TransportDeployModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = MLDeployModelRequest.fromActionRequest(request);
String modelId = deployModelRequest.getModelId();
Boolean isUserInitiatedDeployRequest = deployModelRequest.isUserInitiatedDeployRequest();
User user = RestActionUtils.getUserContext(client);
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
Expand All @@ -143,7 +144,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (isHidden != null && isHidden) {
if (!isUserInitiatedDeployRequest) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
} else if (isHidden != null && isHidden) {
if (isSuperAdmin) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeploy
ImmutableMap.of(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, Optional.ofNullable(autoRedeployRetryTimes).orElse(0) + 1)
);

MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, nodeIds, false, true);
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, nodeIds, false, true, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, listener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ private void updateModelRegisterStateAsDone(
void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) {
String[] modelNodeIds = registerModelInput.getModelNodeIds();
log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds));
MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true);
MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true, true);
ActionListener<MLDeployModelResponse> listener = ActionListener
.wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e));
client.execute(MLDeployModelAction.INSTANCE, request, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ public void setup() {
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);

MLStat mlStat = mock(MLStat.class);
Expand Down Expand Up @@ -218,6 +220,30 @@ public void testDoExecute_success() {
verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class));
}

public void testDoExecute_success_not_userInitiatedRequest() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));

when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(false);

IndexResponse indexResponse = mock(IndexResponse.class);
when(indexResponse.getId()).thenReturn("mockIndexId");
doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class));

ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class);
transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener);
verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class));
}

public void testDoExecute_success_hidden_model() {
transportDeployModelAction = spy(
new TransportDeployModelAction(
Expand Down

0 comments on commit a31215d

Please sign in to comment.