diff --git a/serving/src/main/java/ai/djl/serving/models/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java index a750b7951c..4437bd1336 100644 --- a/serving/src/main/java/ai/djl/serving/models/ModelManager.java +++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java @@ -83,7 +83,10 @@ public CompletableFuture<Void> registerWorkflow(Workflow workflow) { return CompletableFuture.supplyAsync( () -> { - for (ModelInfo<Input, Output> model : workflow.getModels()) { + Map<String, ModelInfo<Input, Output>> models = workflow.getModelMap(); + for (Map.Entry<String, ModelInfo<Input, Output>> entry : models.entrySet()) { + String key = entry.getKey(); + ModelInfo<Input, Output> model = entry.getValue(); try { // download model and configure per model settings model.initialize(); @@ -92,14 +95,21 @@ public CompletableFuture<Void> registerWorkflow(Workflow workflow) { String engine = model.getEngineName(); DependencyManager dm = DependencyManager.getInstance(); dm.installEngine(engine); - wlm.registerModel(model); + WorkerPool<Input, Output> wp = wlm.getWorkerPool(model); + if (wp != null) { + models.put(key, wp.getModel()); + wp.increaseRef(); + logger.info("Model {} is registered by other workflow", model); + continue; + } + wlm.registerModel(model); String[] devices = model.getLoadOnDevices(); logger.info("Loading model on {}:{}", engine, Arrays.toString(devices)); for (String deviceName : devices) { int minWorkers = model.getMinWorkers(); int maxWorkers = model.getMaxWorkers(); - modelManager.initWorkers(model, deviceName, minWorkers, maxWorkers); + initWorkers(model, deviceName, minWorkers, maxWorkers); } } catch (IOException | ModelException e) { throw new CompletionException(e); diff --git a/serving/src/main/java/ai/djl/serving/workflow/Workflow.java b/serving/src/main/java/ai/djl/serving/workflow/Workflow.java index 6fa29bb443..c23b8268ca 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/Workflow.java +++ b/serving/src/main/java/ai/djl/serving/workflow/Workflow.java @@ -115,6 +115,15 @@ public Collection<ModelInfo<Input, Output>> getModels() { return
models.values(); } + /** + * Returns the model map in the workflow. + * + * @return the model map in the workflow + */ + public Map<String, ModelInfo<Input, Output>> getModelMap() { + return models; + } + /** * Executes a workflow with an input. * diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index cb581ad781..f8f761950a 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -264,7 +264,7 @@ public void testWorkflows() assertTrue(server.isRunning()); Channel channel = initTestChannel(); - testPredictionsModels(channel); + testPredictions(channel, new String[] {"/predictions/m"}); testPredictionsWorkflows(channel); channel.close().sync(); diff --git a/serving/src/test/java/ai/djl/serving/WorkflowTest.java b/serving/src/test/java/ai/djl/serving/WorkflowTest.java index 17f240642d..d4f60322db 100644 --- a/serving/src/test/java/ai/djl/serving/WorkflowTest.java +++ b/serving/src/test/java/ai/djl/serving/WorkflowTest.java @@ -78,7 +78,7 @@ public void testFunctions() throws IOException, BadWorkflowException { public void testLocalPerf() throws IOException, BadWorkflowException { Path workflowFile = Paths.get("src/test/resources/workflows/localPerf.json"); Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); - ModelInfo<Input, Output> m = workflow.getModels().stream().findFirst().get(); + ModelInfo<Input, Output> m = workflow.getModels().iterator().next(); Assert.assertEquals(m.getQueueSize(), 102); Assert.assertEquals(m.getMaxIdleSeconds(), 62); diff --git a/serving/src/test/resources/workflow.config.properties b/serving/src/test/resources/workflow.config.properties index 2190424a5e..46a0378963 100644 --- a/serving/src/test/resources/workflow.config.properties +++ b/serving/src/test/resources/workflow.config.properties @@ -2,7 +2,7 @@ inference_address=https://127.0.0.1:8443 management_address=https://127.0.0.1:8443 model_store=build/models
-load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz,https://resources.djl.ai/test-models/basic-serving-workflow.json +load_models=m=https://resources.djl.ai/test-models/mlp.tar.gz,https://resources.djl.ai/test-models/basic-serving-workflow.json private_key_file=src/test/resources/key.pem certificate_file=src/test/resources/certs.pem max_request_size=10485760 diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java index 90e62769ee..ea308518c3 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java @@ -71,8 +71,11 @@ public WorkerPool<I, O> registerModel(ModelInfo<I, O> modelInfo) { */ public void unregisterModel(ModelInfo<I, O> model) { WorkerPool<I, O> pool = getWorkerPool(model); - pool.shutdownWorkers(); - workerPools.remove(model); + if (pool.decreaseRef() <= 0) { + logger.info("Unloading model: {}", model); + pool.shutdownWorkers(); + workerPools.remove(model); + } } /** diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java index c4dafb6ca3..d6917c8201 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java @@ -28,6 +28,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; /** @@ -43,6 +44,7 @@ public class WorkerPool<I, O> { private ExecutorService threadPool; private Map<Device, WorkerGroup<I, O>> workerGroups; private LinkedBlockingDeque<Job<I, O>> jobQueue; + private AtomicInteger refCnt; /** * Construct and initial data structure.
@@ -54,9 +56,29 @@ public class WorkerPool<I, O> { this.model = model; this.threadPool = threadPool; workerGroups = new ConcurrentHashMap<>(); + refCnt = new AtomicInteger(1); } - ModelInfo<I, O> getModel() { + /** Increases the reference count. */ + public void increaseRef() { + refCnt.incrementAndGet(); + } + + /** + * Decrease the reference count and return the current count. + * + * @return the current count + */ + public int decreaseRef() { + return refCnt.decrementAndGet(); + } + + /** + * Returns the model of the worker pool. + * + * @return the model of the worker pool + */ + public ModelInfo<I, O> getModel() { + return model; }