add remote predict thread pool (opensearch-project#2207)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
ylwu-amzn authored and Zhangxunmt committed Mar 21, 2024
1 parent 7722020 commit 045915c
Showing 2 changed files with 40 additions and 7 deletions.
@@ -301,6 +301,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
     public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute";
     public static final String TRAIN_THREAD_POOL = "opensearch_ml_train";
     public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict";
+    public static final String REMOTE_PREDICT_THREAD_POOL = "opensearch_ml_predict_remote";
     public static final String REGISTER_THREAD_POOL = "opensearch_ml_register";
     public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
     public static final String ML_BASE_URI = "/_plugins/_ml";
@@ -824,9 +825,25 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
             ML_THREAD_POOL_PREFIX + PREDICT_THREAD_POOL,
             false
         );
+        FixedExecutorBuilder remotePredictThreadPool = new FixedExecutorBuilder(
+            settings,
+            REMOTE_PREDICT_THREAD_POOL,
+            OpenSearchExecutors.allocatedProcessors(settings) * 4,
+            10000,
+            ML_THREAD_POOL_PREFIX + REMOTE_PREDICT_THREAD_POOL,
+            false
+        );
 
         return ImmutableList
-            .of(generalThreadPool, registerModelThreadPool, deployModelThreadPool, executeThreadPool, trainThreadPool, predictThreadPool);
+            .of(
+                generalThreadPool,
+                registerModelThreadPool,
+                deployModelThreadPool,
+                executeThreadPool,
+                trainThreadPool,
+                predictThreadPool,
+                remotePredictThreadPool
+            );
     }
 
     @Override
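For context, a minimal sketch of how a pool registered through getExecutorBuilders above is later looked up by name and used to offload work. The wrapper class and method names here are illustrative only, not part of this commit; the pool name, the ThreadPool lookup, and the FixedExecutorBuilder sizing (allocatedProcessors * 4 threads, 10,000 queued tasks) come from the diff.

// Illustrative sketch, not part of this commit: dispatching a remote-connector call
// onto the pool registered above. Assumes the ThreadPool instance that the node
// hands to the plugin's components at startup.
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.threadpool.ThreadPool;

public class RemotePredictDispatchSketch {
    private final ThreadPool threadPool;

    public RemotePredictDispatchSketch(ThreadPool threadPool) {
        this.threadPool = threadPool;
    }

    public void dispatch(Runnable remoteCall) {
        // executor(name) resolves the fixed executor created by FixedExecutorBuilder;
        // "opensearch_ml_predict_remote" is backed by allocatedProcessors * 4 threads
        // with a bounded queue of 10,000 tasks.
        threadPool.executor(MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL).execute(remoteCall);
    }
}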
@@ -11,6 +11,7 @@
 import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
 import static org.opensearch.ml.permission.AccessController.getUserContext;
 import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL;
+import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL;
 
 import java.time.Instant;
 import java.util.UUID;
@@ -162,13 +163,14 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
         MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
         Instant now = Instant.now();
         String modelId = request.getModelId();
+        FunctionName functionName = request.getMlInput().getFunctionName();
         MLTask mlTask = MLTask
             .builder()
             .taskId(UUID.randomUUID().toString())
             .modelId(modelId)
             .taskType(MLTaskType.PREDICTION)
             .inputType(inputDataType)
-            .functionName(request.getMlInput().getFunctionName())
+            .functionName(functionName)
             .state(MLTaskState.CREATED)
             .workerNodes(ImmutableList.of(clusterService.localNode().getId()))
             .createTime(now)
@@ -186,16 +188,22 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
                     handleAsyncMLTaskFailure(mlTask, e);
                     listener.onFailure(e);
                 });
-                mlInputDatasetHandler.parseSearchQueryInput(mlInput.getInputDataset(), threadedActionListener(dataFrameActionListener));
+                mlInputDatasetHandler
+                    .parseSearchQueryInput(mlInput.getInputDataset(), threadedActionListener(functionName, dataFrameActionListener));
                 break;
             case DATA_FRAME:
             case TEXT_DOCS:
             default:
-                threadPool.executor(PREDICT_THREAD_POOL).execute(() -> { predict(modelId, mlTask, mlInput, listener); });
+                String threadPoolName = getPredictThreadPool(functionName);
+                threadPool.executor(threadPoolName).execute(() -> { predict(modelId, mlTask, mlInput, listener); });
                 break;
         }
     }
 
+    private String getPredictThreadPool(FunctionName functionName) {
+        return functionName == FunctionName.REMOTE ? REMOTE_PREDICT_THREAD_POOL : PREDICT_THREAD_POOL;
+    }
+
     private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
         ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
         // track ML task count and add ML task into cache
@@ -287,7 +295,14 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
                     handlePredictFailure(mlTask, internalListener, e, true, modelId);
                 });
                 GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
-                client.get(getRequest, threadedActionListener(ActionListener.runBefore(getModelListener, () -> context.restore())));
+                client
+                    .get(
+                        getRequest,
+                        threadedActionListener(
+                            mlTask.getFunctionName(),
+                            ActionListener.runBefore(getModelListener, () -> context.restore())
+                        )
+                    );
             } catch (Exception e) {
                 log.error("Failed to get model " + mlTask.getModelId(), e);
                 handlePredictFailure(mlTask, internalListener, e, true, modelId);
@@ -299,8 +314,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
         }
     }
 
-    private <T> ThreadedActionListener<T> threadedActionListener(ActionListener<T> listener) {
-        return new ThreadedActionListener<>(log, threadPool, PREDICT_THREAD_POOL, listener, false);
+    private <T> ThreadedActionListener<T> threadedActionListener(FunctionName functionName, ActionListener<T> listener) {
+        String threadPoolName = getPredictThreadPool(functionName);
+        return new ThreadedActionListener<>(log, threadPool, threadPoolName, listener, false);
     }
 
     private void handlePredictFailure(
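As a usage note, the second file routes both the direct predict call and its listener callbacks through the same pool choice: ThreadedActionListener re-dispatches onResponse and onFailure onto the named executor, so response handling for a REMOTE model lands on opensearch_ml_predict_remote rather than the calling transport thread. The helper below and the ActionListener import path (which has moved between 2.x releases) are assumptions for illustration; the five-argument ThreadedActionListener constructor mirrors the one used in the diff.

// Illustrative sketch only, assuming OpenSearch 2.x import locations.
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.core.action.ActionListener; // older 2.x lines use org.opensearch.action.ActionListener
import org.opensearch.threadpool.ThreadPool;

public final class RemoteListenerSketch {
    private static final Logger log = LogManager.getLogger(RemoteListenerSketch.class);

    private RemoteListenerSketch() {}

    // Wraps a plain listener so its callbacks run on the remote predict pool
    // instead of the thread that delivers the response.
    public static <T> ActionListener<T> onRemotePredictPool(ThreadPool threadPool, ActionListener<T> delegate) {
        return new ThreadedActionListener<>(log, threadPool, "opensearch_ml_predict_remote", delegate, false);
    }
}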