Skip to content

Commit 88332a0

Browse files
committed
add sync up for planning worker nodes
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 60bdb41 commit 88332a0

File tree

4 files changed

+38
-3
lines changed

4 files changed

+38
-3
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ public void syncWorkerNode(Set<String> workerNodes) {
146146
this.workerNodes.addAll(workerNodes);
147147
}
148148

149+
public void syncPlanningWorkerNode(Set<String> planningWorkerNodes) {
150+
this.targetWorkerNodes.clear();
151+
this.targetWorkerNodes.addAll(planningWorkerNodes);
152+
}
153+
149154
public boolean isDeployToAllNodes() {
150155
return this.deployToAllNodes != null && this.deployToAllNodes;
151156
}

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,6 @@ public String[] getTargetWorkerNodes(String modelId) {
551551
return modelCache.getTargetWorkerNodes();
552552
}
553553

554-
555554
/**
556555
* Add worker node of model.
557556
*
@@ -624,6 +623,19 @@ public void syncWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
624623
});
625624
}
626625

626+
/**
627+
* Sync planning worker nodes for all models.
628+
*
629+
* @param modelPlanningWorkerNodes planning worker nodes of all models
630+
*/
631+
public void syncPlanningWorkerNodes(Map<String, Set<String>> modelPlanningWorkerNodes) {
632+
log.debug("sync model planning worker nodes");
633+
modelPlanningWorkerNodes.entrySet().forEach(entry -> {
634+
MLModelCache modelCache = getOrCreateModelCache(entry.getKey());
635+
modelCache.syncPlanningWorkerNode(entry.getValue());
636+
});
637+
}
638+
627639
/**
628640
* Clear worker nodes for all models.
629641
*/

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2480,6 +2480,23 @@ public String[] getExpiredModels() {
24802480
*/
24812481
public synchronized void syncModelWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
24822482
modelCacheHelper.syncWorkerNodes(modelWorkerNodes);
2483+
2484+
syncModelPlanningWorkerNodes(modelWorkerNodes);
2485+
}
2486+
2487+
public synchronized void syncModelPlanningWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
2488+
Map<String, Set<String>> modelPlanningWorkerNodes = new HashMap<>();
2489+
modelWorkerNodes.keySet().forEach(modelId -> {
2490+
FunctionName functionName = modelCacheHelper.getFunctionName(modelId);
2491+
boolean isDeployToAll = modelCacheHelper.getDeployToAllNodes(modelId);
2492+
if (!isDeployToAll) {
2493+
return;
2494+
}
2495+
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName);
2496+
Set<String> eligibleNodeIds = Arrays.stream(eligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());
2497+
modelPlanningWorkerNodes.put(modelId, eligibleNodeIds);
2498+
});
2499+
modelCacheHelper.syncPlanningWorkerNodes(modelPlanningWorkerNodes);
24832500
}
24842501

24852502
/**

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,8 @@ public void validateOutputSchema(String modelId, ModelTensorOutput output) {
655655
}
656656

657657
private boolean requiresAutoDeployment(String[] workerNodes, String[] targetWorkerNodes) {
658-
return workerNodes == null || workerNodes.length == 0 ||
659-
(targetWorkerNodes != null && workerNodes.length < targetWorkerNodes.length);
658+
return workerNodes == null
659+
|| workerNodes.length == 0
660+
|| (targetWorkerNodes != null && workerNodes.length < targetWorkerNodes.length);
660661
}
661662
}

0 commit comments

Comments
 (0)