Skip to content

Commit ba93cec

Browse files
committed
add sync up for planning worker nodes
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 60bdb41 commit ba93cec

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,11 @@ public void syncWorkerNode(Set<String> workerNodes) {
146146
this.workerNodes.addAll(workerNodes);
147147
}
148148

149+
public void syncPlanningWorkerNode(Set<String> planningWorkerNodes) {
150+
this.targetWorkerNodes.clear();
151+
this.targetWorkerNodes.addAll(planningWorkerNodes);
152+
}
153+
149154
public boolean isDeployToAllNodes() {
150155
return this.deployToAllNodes != null && this.deployToAllNodes;
151156
}

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -551,7 +551,6 @@ public String[] getTargetWorkerNodes(String modelId) {
551551
return modelCache.getTargetWorkerNodes();
552552
}
553553

554-
555554
/**
556555
* Add worker node of model.
557556
*
@@ -624,6 +623,19 @@ public void syncWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
624623
});
625624
}
626625

626+
/**
627+
* Sync planning worker nodes for all models.
628+
*
629+
* @param modelPlanningWorkerNodes planning worker nodes of all models
630+
*/
631+
public void syncPlanningWorkerNodes(Map<String, Set<String>> modelPlanningWorkerNodes) {
632+
log.debug("sync model planning worker nodes");
633+
modelPlanningWorkerNodes.entrySet().forEach(entry -> {
634+
MLModelCache modelCache = getOrCreateModelCache(entry.getKey());
635+
modelCache.syncPlanningWorkerNode(entry.getValue());
636+
});
637+
}
638+
627639
/**
628640
* Clear worker nodes for all models.
629641
*/

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2480,6 +2480,24 @@ public String[] getExpiredModels() {
24802480
*/
24812481
public synchronized void syncModelWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
24822482
modelCacheHelper.syncWorkerNodes(modelWorkerNodes);
2483+
2484+
syncModelPlanningWorkerNodes(modelWorkerNodes);
2485+
}
2486+
2487+
public synchronized void syncModelPlanningWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
2488+
Map<String, Set<String>> modelPlanningWorkerNodes = new HashMap<>();
2489+
modelWorkerNodes.keySet().forEach( modelId -> {
2490+
FunctionName functionName = modelCacheHelper.getFunctionName(modelId);
2491+
boolean isDeployToAll = modelCacheHelper.getDeployToAllNodes(modelId);
2492+
if (!isDeployToAll) { return; }
2493+
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName);
2494+
Set<String> eligibleNodeIds = Arrays
2495+
.stream(eligibleNodes)
2496+
.map(DiscoveryNode::getId)
2497+
.collect(Collectors.toSet());
2498+
modelPlanningWorkerNodes.put(modelId, eligibleNodeIds);
2499+
});
2500+
modelCacheHelper.syncPlanningWorkerNodes(modelPlanningWorkerNodes);
24832501
}
24842502

24852503
/**

0 commit comments

Comments
 (0)