Skip to content

Commit f1dab23

Browse files
run auto deploy remote model in partially deployed status (opensearch-project#3423) (opensearch-project#3959)
* Run auto deploy for remote models in partially deployed status; * add sync-up for planning worker nodes; * add more unit tests and Javadoc; * rename syncPlanningWorkerNodes per review comments. (cherry picked from commit 8fff3f3) Signed-off-by: Xun Zhang <xunzh@amazon.com> Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent 7db6ef7 commit f1dab23

File tree

6 files changed

+117
-1
lines changed

6 files changed

+117
-1
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ public void syncWorkerNode(Set<String> workerNodes) {
146146
this.workerNodes.addAll(workerNodes);
147147
}
148148

149+
public void syncPlanningWorkerNodes(Set<String> planningWorkerNodes) {
150+
this.targetWorkerNodes.clear();
151+
this.targetWorkerNodes.addAll(planningWorkerNodes);
152+
}
153+
149154
public boolean isDeployToAllNodes() {
150155
return this.deployToAllNodes != null && this.deployToAllNodes;
151156
}

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,20 @@ public String[] getWorkerNodes(String modelId) {
537537
return modelCache.getWorkerNodes();
538538
}
539539

540+
/**
541+
* Get target worker nodes of model.
542+
*
543+
* @param modelId model id
544+
* @return array of node id; return null if model not exists in cache
545+
*/
546+
public String[] getTargetWorkerNodes(String modelId) {
547+
MLModelCache modelCache = modelCaches.get(modelId);
548+
if (modelCache == null) {
549+
return null;
550+
}
551+
return modelCache.getTargetWorkerNodes();
552+
}
553+
540554
/**
541555
* Add worker node of model.
542556
*
@@ -609,6 +623,19 @@ public void syncWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
609623
});
610624
}
611625

626+
/**
627+
* Sync planning worker nodes for all models.
628+
*
629+
* @param modelPlanningWorkerNodes planning worker nodes of all models
630+
*/
631+
public void syncPlanningWorkerNodes(Map<String, Set<String>> modelPlanningWorkerNodes) {
632+
log.debug("sync model planning worker nodes");
633+
modelPlanningWorkerNodes.entrySet().forEach(entry -> {
634+
MLModelCache modelCache = getOrCreateModelCache(entry.getKey());
635+
modelCache.syncPlanningWorkerNodes(entry.getValue());
636+
});
637+
}
638+
612639
/**
613640
* Clear worker nodes for all models.
614641
*/

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,6 +2436,16 @@ public int getWorkerNodesSize(String modelId, FunctionName functionName) {
24362436
return getWorkerNodes(modelId, functionName, false).length;
24372437
}
24382438

2439+
/**
2440+
* Get target/planning worker node of a specific model.
2441+
*
2442+
* @param modelId model id
2443+
* @return list of planning worker node ids
2444+
*/
2445+
public String[] getTargetWorkerNodes(String modelId) {
2446+
return modelCacheHelper.getTargetWorkerNodes(modelId);
2447+
}
2448+
24392449
/**
24402450
* Get predictable instance with model id.
24412451
*
@@ -2476,6 +2486,23 @@ public String[] getExpiredModels() {
24762486
*/
24772487
public synchronized void syncModelWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
24782488
modelCacheHelper.syncWorkerNodes(modelWorkerNodes);
2489+
2490+
syncModelPlanningWorkerNodes(modelWorkerNodes);
2491+
}
2492+
2493+
public synchronized void syncModelPlanningWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
2494+
Map<String, Set<String>> modelPlanningWorkerNodes = new HashMap<>();
2495+
modelWorkerNodes.keySet().forEach(modelId -> {
2496+
FunctionName functionName = modelCacheHelper.getFunctionName(modelId);
2497+
boolean isDeployToAll = modelCacheHelper.getDeployToAllNodes(modelId);
2498+
if (!isDeployToAll) {
2499+
return;
2500+
}
2501+
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName);
2502+
Set<String> eligibleNodeIds = Arrays.stream(eligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());
2503+
modelPlanningWorkerNodes.put(modelId, eligibleNodeIds);
2504+
});
2505+
modelCacheHelper.syncPlanningWorkerNodes(modelPlanningWorkerNodes);
24792506
}
24802507

24812508
/**

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ public void dispatchTask(
180180
}
181181
}, listener::onFailure);
182182
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
183-
if (workerNodes == null || workerNodes.length == 0) {
183+
String[] targetWorkerNodes = mlModelManager.getTargetWorkerNodes(modelId);
184+
185+
if (requiresAutoDeployment(workerNodes, targetWorkerNodes)) {
184186
if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
185187
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
186188
mlModelManager.getModel(modelId, request.getTenantId(), ActionListener.runBefore(ActionListener.wrap(model -> {
@@ -651,4 +653,10 @@ public void validateOutputSchema(String modelId, ModelTensorOutput output) {
651653
}
652654
}
653655
}
656+
657+
private boolean requiresAutoDeployment(String[] workerNodes, String[] targetWorkerNodes) {
658+
return workerNodes == null
659+
|| workerNodes.length == 0
660+
|| (targetWorkerNodes != null && workerNodes.length < targetWorkerNodes.length);
661+
}
654662
}

plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,27 @@ public void testSyncWorkerNodes_NullModelState() {
298298
assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId));
299299
}
300300

301+
public void testGetTargetWorkerNodes() {
302+
String[] workerNodes = cacheHelper.getTargetWorkerNodes(modelId);
303+
assertNull(workerNodes);
304+
String newNodeId = "new_node_id";
305+
Map<String, Set<String>> modelPlannningWorkerNodes = new HashMap<>();
306+
modelPlannningWorkerNodes.put(modelId, ImmutableSet.of(newNodeId));
307+
cacheHelper.syncPlanningWorkerNodes(modelPlannningWorkerNodes);
308+
workerNodes = cacheHelper.getTargetWorkerNodes(modelId);
309+
assertArrayEquals(new String[] { "new_node_id" }, workerNodes);
310+
311+
}
312+
313+
public void testSyncPlanningWorkerNodes() {
314+
String newNodeId = "new_node_id";
315+
Map<String, Set<String>> modelPlannningWorkerNodes = new HashMap<>();
316+
modelPlannningWorkerNodes.put(modelId, ImmutableSet.of(newNodeId));
317+
cacheHelper.syncPlanningWorkerNodes(modelPlannningWorkerNodes);
318+
assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels());
319+
assertArrayEquals(new String[] { newNodeId }, cacheHelper.getTargetWorkerNodes(modelId));
320+
}
321+
301322
public void testSyncWorkerNodes_ModelState() {
302323
String modelId2 = "model_id2";
303324
cacheHelper.initModelState(modelId2, MLModelState.DEPLOYED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import static org.mockito.Mockito.times;
2222
import static org.mockito.Mockito.verify;
2323
import static org.mockito.Mockito.when;
24+
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
2425
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
2526
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE;
2627
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS;
@@ -50,6 +51,7 @@
5051
import static org.opensearch.ml.utils.TestHelper.copyFile;
5152

5253
import java.io.IOException;
54+
import java.net.InetAddress;
5355
import java.net.URISyntaxException;
5456
import java.nio.charset.StandardCharsets;
5557
import java.nio.file.Path;
@@ -65,6 +67,7 @@
6567
import java.util.concurrent.ExecutorService;
6668
import java.util.concurrent.atomic.AtomicBoolean;
6769
import java.util.function.Supplier;
70+
import java.util.stream.Collectors;
6871

6972
import org.junit.Before;
7073
import org.junit.Ignore;
@@ -75,11 +78,13 @@
7578
import org.mockito.Mock;
7679
import org.mockito.MockitoAnnotations;
7780
import org.opensearch.OpenSearchStatusException;
81+
import org.opensearch.Version;
7882
import org.opensearch.action.get.GetRequest;
7983
import org.opensearch.action.get.GetResponse;
8084
import org.opensearch.action.index.IndexResponse;
8185
import org.opensearch.action.update.UpdateRequest;
8286
import org.opensearch.action.update.UpdateResponse;
87+
import org.opensearch.cluster.node.DiscoveryNode;
8388
import org.opensearch.cluster.service.ClusterApplierService;
8489
import org.opensearch.cluster.service.ClusterService;
8590
import org.opensearch.common.settings.ClusterSettings;
@@ -90,6 +95,7 @@
9095
import org.opensearch.core.common.breaker.CircuitBreaker;
9196
import org.opensearch.core.common.breaker.CircuitBreakingException;
9297
import org.opensearch.core.common.bytes.BytesReference;
98+
import org.opensearch.core.common.transport.TransportAddress;
9399
import org.opensearch.core.index.shard.ShardId;
94100
import org.opensearch.core.rest.RestStatus;
95101
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -986,6 +992,28 @@ public void testSyncModelWorkerNodes() {
986992
verify(modelCacheHelper).syncWorkerNodes(eq(modelWorkerNodes));
987993
}
988994

995+
public void testSyncModelPlanningWorkerNodes() {
996+
DiscoveryNode localNode = new DiscoveryNode(
997+
"foo1",
998+
"node1",
999+
new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
1000+
Collections.emptyMap(),
1001+
Collections.singleton(CLUSTER_MANAGER_ROLE),
1002+
Version.CURRENT
1003+
);
1004+
1005+
Map<String, Set<String>> modelWorkerNodes = ImmutableMap.of(modelId, ImmutableSet.of("node1"));
1006+
when(modelCacheHelper.getFunctionName(modelId)).thenReturn(FunctionName.TEXT_EMBEDDING);
1007+
when(modelCacheHelper.getDeployToAllNodes(modelId)).thenReturn(true);
1008+
DiscoveryNode[] planningWorkerNodes = new DiscoveryNode[] { localNode };
1009+
when(nodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING)).thenReturn(planningWorkerNodes);
1010+
modelManager.syncModelPlanningWorkerNodes(modelWorkerNodes);
1011+
verify(modelCacheHelper)
1012+
.syncPlanningWorkerNodes(
1013+
Map.of(modelId, Arrays.stream(planningWorkerNodes).map(DiscoveryNode::getId).collect(Collectors.toSet()))
1014+
);
1015+
}
1016+
9891017
public void testClearRoutingTable() {
9901018
modelManager.clearRoutingTable();
9911019
verify(modelCacheHelper).clearWorkerNodes();

0 commit comments

Comments
 (0)