Skip to content

Commit

Permalink
apply multi-tenancy and sdk client in Connector (Create + Get + Delete)
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
  • Loading branch information
dhrubo-os committed Jan 12, 2025
1 parent 7e8b253 commit a795fce
Show file tree
Hide file tree
Showing 28 changed files with 1,484 additions and 500 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,21 @@ default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
return actionFuture;
}

/**
* Delete connector for remote model
* @param connectorId The id of the connector to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteConnector(String connectorId, String tenantId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteConnector(connectorId, tenantId, actionFuture);
return actionFuture;
}

void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);

void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* Register model group
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
mlInput.setParameters(mlAlgoParams);
switch (action) {
case TRAIN:
boolean asyncTask = args.containsKey(ASYNC) ? (boolean) args.get(ASYNC) : false;
boolean asyncTask = args.containsKey(ASYNC) && (boolean) args.get(ASYNC);
train(mlInput, asyncTask, listener);
break;
case PREDICT:
Expand Down Expand Up @@ -174,30 +174,19 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
ActionListener<MLModelGetResponse> internalListener = ActionListener.wrap(predictionResponse -> {
listener.onResponse(predictionResponse.getMlModel());
}, listener::onFailure);
ActionListener<MLModelGetResponse> actionListener = wrapActionListener(internalListener, res -> {
MLModelGetResponse getResponse = MLModelGetResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
return wrapActionListener(internalListener, MLModelGetResponse::fromActionResponse);
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();

client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
client
.execute(
MLModelSearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
);
client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
Expand Down Expand Up @@ -238,19 +227,12 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();

client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
client
.execute(
MLTaskSearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
);
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
Expand Down Expand Up @@ -280,9 +262,23 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
client
.execute(
MLConnectorDeleteAction.INSTANCE,
connectorDeleteRequest,
ActionListener.wrap(listener::onResponse, listener::onFailure)
);
}

@Override
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId);
client
.execute(
MLConnectorDeleteAction.INSTANCE,
connectorDeleteRequest,
ActionListener.wrap(listener::onResponse, listener::onFailure)
);
}

@Override
Expand All @@ -294,9 +290,7 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
Expand Down Expand Up @@ -324,123 +318,78 @@ private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
}, listener::onFailure);
ActionListener<MLToolsListResponse> actionListener = wrapActionListener(internalListener, res -> {
MLToolsListResponse getResponse = MLToolsListResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
return wrapActionListener(internalListener, MLToolsListResponse::fromActionResponse);
}

private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(ActionListener<ToolMetadata> listener) {
ActionListener<MLToolGetResponse> internalListener = ActionListener.wrap(mlModelGetResponse -> {
listener.onResponse(mlModelGetResponse.getToolMetadata());
}, listener::onFailure);
ActionListener<MLToolGetResponse> actionListener = wrapActionListener(internalListener, res -> {
MLToolGetResponse getResponse = MLToolGetResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
return wrapActionListener(internalListener, MLToolGetResponse::fromActionResponse);
}

private ActionListener<MLConfigGetResponse> getMlGetConfigResponseActionListener(ActionListener<MLConfig> listener) {
ActionListener<MLConfigGetResponse> internalListener = ActionListener.wrap(mlConfigGetResponse -> {
listener.onResponse(mlConfigGetResponse.getMlConfig());
}, listener::onFailure);
ActionListener<MLConfigGetResponse> actionListener = wrapActionListener(internalListener, res -> {
MLConfigGetResponse getResponse = MLConfigGetResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
return wrapActionListener(internalListener, MLConfigGetResponse::fromActionResponse);
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
return mlRegisterAgentResponse;
});
return actionListener;
return wrapActionListener(listener, MLRegisterAgentResponse::fromActionResponse);
}

private ActionListener<MLTaskGetResponse> getMLTaskResponseActionListener(ActionListener<MLTask> listener) {
ActionListener<MLTaskGetResponse> internalListener = ActionListener
.wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure);
ActionListener<MLTaskGetResponse> actionListener = wrapActionListener(internalListener, response -> {
MLTaskGetResponse getResponse = MLTaskGetResponse.fromActionResponse(response);
return getResponse;
});
return actionListener;
return wrapActionListener(internalListener, MLTaskGetResponse::fromActionResponse);
}

private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionListener(ActionListener<MLDeployModelResponse> listener) {
ActionListener<MLDeployModelResponse> actionListener = wrapActionListener(listener, response -> {
MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
return wrapActionListener(listener, MLDeployModelResponse::fromActionResponse);
}

private ActionListener<MLUndeployModelsResponse> getMlUndeployModelsResponseActionListener(
ActionListener<MLUndeployModelsResponse> listener
) {
ActionListener<MLUndeployModelsResponse> actionListener = wrapActionListener(listener, response -> {
MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
return wrapActionListener(listener, MLUndeployModelsResponse::fromActionResponse);
}

private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
ActionListener<MLCreateConnectorResponse> listener
) {
ActionListener<MLCreateConnectorResponse> actionListener = wrapActionListener(listener, response -> {
MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response);
return createConnectorResponse;
});
return actionListener;
return wrapActionListener(listener, MLCreateConnectorResponse::fromActionResponse);
}

private ActionListener<MLRegisterModelGroupResponse> getMlRegisterModelGroupResponseActionListener(
ActionListener<MLRegisterModelGroupResponse> listener
) {
ActionListener<MLRegisterModelGroupResponse> actionListener = wrapActionListener(listener, response -> {
MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response);
return registerModelGroupResponse;
});
return actionListener;
return wrapActionListener(listener, MLRegisterModelGroupResponse::fromActionResponse);
}

private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
listener.onResponse(predictionResponse.getOutput());
}, listener::onFailure);
ActionListener<MLTaskResponse> actionListener = wrapActionListener(internalListener, res -> {
MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res);
return predictionResponse;
});
return actionListener;
return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
}

private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseActionListener(
ActionListener<MLRegisterModelResponse> listener
) {
ActionListener<MLRegisterModelResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res);
return registerModelResponse;
});
return actionListener;
return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse);
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
) {
ActionListener<T> actionListener = ActionListener.wrap(r -> {
return ActionListener.wrap(r -> {
listener.onResponse(recreate.apply(r));
;
}, e -> { listener.onFailure(e); });
return actionListener;
}, listener::onFailure);
}

private void validateMLInput(MLInput mlInput, boolean requireInput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ public void execute(FunctionName name, Input input, ActionListener<MLExecuteTask
listener.onResponse(mlExecuteTaskResponse);
}

@Override
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.input.Constants.ACTION;
import static org.opensearch.ml.common.input.Constants.ALGORITHM;
import static org.opensearch.ml.common.input.Constants.ASYNC;
import static org.opensearch.ml.common.input.Constants.KMEANS;
import static org.opensearch.ml.common.input.Constants.MODELID;
import static org.opensearch.ml.common.input.Constants.PREDICT;
Expand Down Expand Up @@ -251,6 +252,42 @@ public void predict() {
assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult());
}

@Test
public void execute_train_asyncTask() {
String modelId = "test_model_id";
String status = "InProgress";
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build();
actionListener.onResponse(MLTaskResponse.builder().output(output).build());
return null;
}).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

ArgumentCaptor<MLOutput> argumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
Map<String, Object> args = new HashMap<>();
args.put(ACTION, TRAIN);
args.put(ALGORITHM, KMEANS);
args.put(ASYNC, true);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build();
machineLearningNodeClient.run(mlInput, args, trainingActionListener);

verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any());
verify(trainingActionListener).onResponse(argumentCaptor.capture());
assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId());
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
}

@Test
public void execute_predict_missing_modelId() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("The model ID is required for prediction.");
Map<String, Object> args = new HashMap<>();
args.put(ACTION, PREDICT);
args.put(ALGORITHM, KMEANS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build();
machineLearningNodeClient.run(mlInput, args, dataFrameActionListener);
}

@Test
public void predict_Exception_WithNullAlgorithm() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down Expand Up @@ -288,6 +325,27 @@ public void train() {
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
}

@Test
public void registerModelGroup_withValidInput() {
doAnswer(invocation -> {
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(2);
MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse("groupId", "created");
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), any(), any());

MLRegisterModelGroupInput input = MLRegisterModelGroupInput
.builder()
.name("test")
.description("description")
.backendRoles(Arrays.asList("role1", "role2"))
.modelAccessMode(AccessMode.PUBLIC)
.build();

machineLearningNodeClient.registerModelGroup(input, registerModelGroupResponseActionListener);
verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any());
}

@Test
public void train_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down Expand Up @@ -499,6 +557,26 @@ public void getModel() {
assertEquals(modelContent, argumentCaptor.getValue().getContent());
}

@Test
public void deleteConnector_withTenantId() {
String connectorId = "connectorId";
String tenantId = "tenantId";
doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, connectorId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLConnectorDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
machineLearningNodeClient.deleteConnector(connectorId, tenantId, deleteConnectorActionListener);

verify(client).execute(eq(MLConnectorDeleteAction.INSTANCE), isA(MLConnectorDeleteRequest.class), any());
verify(deleteConnectorActionListener).onResponse(argumentCaptor.capture());
assertEquals(connectorId, (argumentCaptor.getValue()).getId());
}

@Test
public void deleteModel() {
String modelId = "testModelId";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ public AwsConnector(
List<String> backendRoles,
AccessMode accessMode,
User owner,
ConnectorClientConfig connectorClientConfig
ConnectorClientConfig connectorClientConfig,
String tenantId
) {
super(
name,
Expand All @@ -54,7 +55,8 @@ public AwsConnector(
backendRoles,
accessMode,
owner,
connectorClientConfig
connectorClientConfig,
tenantId
);
validate();
}
Expand Down
Loading

0 comments on commit a795fce

Please sign in to comment.