Skip to content

Commit

Permalink
fix no worker node exception for remote embedding model (opensearch-p…
Browse files Browse the repository at this point in the history
…roject#1482) (opensearch-project#1511)

* fix no worker node exception for remote embedding model

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* only add model info to cache if model cache exist

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
(cherry picked from commit 6f83b9f)

Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
2 people authored and rbhavna committed Nov 16, 2023
1 parent 2a7a3ab commit 0c6565a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,16 @@ private void executePredict(
String requestId = mlPredictionTaskRequest.getRequestID();
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
// For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as
// TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get
// from model cache first.
FunctionName functionName = modelCacheHelper
.getOptionalFunctionName(modelId)
.orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
mlPredictTaskRunner
.run(
mlPredictionTaskRequest.getMlInput().getAlgorithm(),
// This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
functionName,
mlPredictionTaskRequest,
transportService,
ActionListener.runAfter(wrappedListener, () -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,10 @@ public boolean getDeployToAllNodes(String modelId) {
}

public void setModelInfo(String modelId, MLModel mlModel) {
MLModelCache mlModelCache = getExistingModelCache(modelId);
mlModelCache.setModelInfo(mlModel);
MLModelCache mlModelCache = modelCaches.get(modelId);
if (mlModelCache != null) {
mlModelCache.setModelInfo(mlModel);
}
}

public MLModel getModelInfo(String modelId) {
Expand Down

0 comments on commit 0c6565a

Please sign in to comment.