Commit 6f83b9f

fix no worker node exception for remote embedding model (opensearch-project#1482)
* fix no worker node exception for remote embedding model

  Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* only add model info to cache if model cache exists

  Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent da5d829 · commit 6f83b9f

2 files changed, +12 -3 lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

Lines changed: 8 additions & 1 deletion
@@ -137,9 +137,16 @@ private void executePredict(
         String requestId = mlPredictionTaskRequest.getRequestID();
         log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
         long startTime = System.nanoTime();
+        // For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as
+        // TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get
+        // from model cache first.
+        FunctionName functionName = modelCacheHelper
+            .getOptionalFunctionName(modelId)
+            .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
         mlPredictTaskRunner
             .run(
-                mlPredictionTaskRequest.getMlInput().getAlgorithm(),
+                // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
+                functionName,
                 mlPredictionTaskRequest,
                 transportService,
                 ActionListener.runAfter(wrappedListener, () -> {
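
For context, here is a simplified, self-contained sketch of the cache-first resolution used above. These are not the real ml-commons classes; getOptionalFunctionName and the map-backed cache below are stand-ins for the helper the diff calls, shown only to illustrate the lookup/fallback logic.

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Simplified stand-ins for the ml-commons types; only the lookup/fallback logic is illustrated.
public class FunctionNameResolutionSketch {
    enum FunctionName { TEXT_EMBEDDING, REMOTE }

    private final Map<String, FunctionName> cachedFunctionNames = new ConcurrentHashMap<>();

    // Mirrors the role of MLModelCacheHelper.getOptionalFunctionName(modelId): empty if the model is not cached.
    Optional<FunctionName> getOptionalFunctionName(String modelId) {
        return Optional.ofNullable(cachedFunctionNames.get(modelId));
    }

    // Prefer the function name recorded when the model was deployed; fall back to the
    // algorithm carried in the prediction request only when nothing is cached.
    FunctionName resolve(String modelId, FunctionName algorithmFromRequest) {
        return getOptionalFunctionName(modelId).orElse(algorithmFromRequest);
    }

    public static void main(String[] args) {
        FunctionNameResolutionSketch sketch = new FunctionNameResolutionSketch();
        sketch.cachedFunctionNames.put("remote-embedding-model", FunctionName.REMOTE);
        // Prints REMOTE even though the request says TEXT_EMBEDDING, so task dispatch
        // routes to the remote path instead of searching for local worker nodes.
        System.out.println(sketch.resolve("remote-embedding-model", FunctionName.TEXT_EMBEDDING));
    }
}

With REMOTE resolved from the cache, the task runner no longer looks for worker nodes hosting a local TEXT_EMBEDDING model, which is what raised the "no worker node" exception the commit title refers to.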

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

Lines changed: 4 additions & 2 deletions
@@ -431,8 +431,10 @@ public boolean getDeployToAllNodes(String modelId) {
     }
 
     public void setModelInfo(String modelId, MLModel mlModel) {
-        MLModelCache mlModelCache = getExistingModelCache(modelId);
-        mlModelCache.setModelInfo(mlModel);
+        MLModelCache mlModelCache = modelCaches.get(modelId);
+        if (mlModelCache != null) {
+            mlModelCache.setModelInfo(mlModel);
+        }
     }
 
     public MLModel getModelInfo(String modelId) {
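
This hunk swaps a lookup that assumed the cache entry exists for a null-guarded update. Below is a rough, self-contained sketch of the before/after behavior; the types are simplified and the exception in the old path is an assumption rather than taken from the source.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class CacheGuardSketch {
    // Minimal stand-in for MLModelCache: just holds the model info object.
    static class ModelCacheEntry {
        volatile Object modelInfo;
    }

    private final Map<String, ModelCacheEntry> modelCaches = new ConcurrentHashMap<>();

    // Old behavior (assumed): fetch the entry and fail if it is missing.
    ModelCacheEntry getExistingModelCache(String modelId) {
        ModelCacheEntry entry = modelCaches.get(modelId);
        if (entry == null) {
            throw new IllegalArgumentException("Model not found in cache: " + modelId);
        }
        return entry;
    }

    // New behavior: only update the cached info when an entry actually exists,
    // so a model that was never cached on this node no longer triggers an exception.
    void setModelInfo(String modelId, Object modelInfo) {
        ModelCacheEntry entry = modelCaches.get(modelId);
        if (entry != null) {
            entry.modelInfo = modelInfo;
        }
    }

    public static void main(String[] args) {
        CacheGuardSketch sketch = new CacheGuardSketch();
        sketch.setModelInfo("uncached-model-id", new Object()); // no-op, no exception
    }
}

This matches the second bullet of the commit message: model info is only added to the cache when a cache entry for the model already exists.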
