Skip to content

Commit

Permalink
[workflow] Allows a model to be shared between workflows (#665)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Apr 25, 2023
1 parent fb68c5b commit c77cb15
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 9 deletions.
16 changes: 13 additions & 3 deletions serving/src/main/java/ai/djl/serving/models/ModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ public CompletableFuture<Void> registerWorkflow(Workflow workflow) {

return CompletableFuture.supplyAsync(
() -> {
for (ModelInfo<Input, Output> model : workflow.getModels()) {
Map<String, ModelInfo<Input, Output>> models = workflow.getModelMap();
for (Map.Entry<String, ModelInfo<Input, Output>> entry : models.entrySet()) {
String key = entry.getKey();
ModelInfo<Input, Output> model = entry.getValue();
try {
// download model and configure per model settings
model.initialize();
Expand All @@ -92,14 +95,21 @@ public CompletableFuture<Void> registerWorkflow(Workflow workflow) {
String engine = model.getEngineName();
DependencyManager dm = DependencyManager.getInstance();
dm.installEngine(engine);
wlm.registerModel(model);
WorkerPool<Input, Output> wp = wlm.getWorkerPool(model);
if (wp != null) {
models.put(key, wp.getModel());
wp.increaseRef();
logger.info("Model {} is registered by other workflow", model);
continue;
}

wlm.registerModel(model);
String[] devices = model.getLoadOnDevices();
logger.info("Loading model on {}:{}", engine, Arrays.toString(devices));
for (String deviceName : devices) {
int minWorkers = model.getMinWorkers();
int maxWorkers = model.getMaxWorkers();
modelManager.initWorkers(model, deviceName, minWorkers, maxWorkers);
initWorkers(model, deviceName, minWorkers, maxWorkers);
}
} catch (IOException | ModelException e) {
throw new CompletionException(e);
Expand Down
9 changes: 9 additions & 0 deletions serving/src/main/java/ai/djl/serving/workflow/Workflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ public Collection<ModelInfo<Input, Output>> getModels() {
return models.values();
}

/**
 * Returns the live, mutable map of model key to {@link ModelInfo} backing this workflow.
 *
 * <p>Unlike {@code getModels()}, this exposes the backing map itself: callers may replace
 * entries in the returned map (e.g. substituting a shared {@code ModelInfo} that was
 * already registered by another workflow), and such changes take effect in this workflow
 * directly.
 *
 * @return the model map in the workflow
 */
public Map<String, ModelInfo<Input, Output>> getModelMap() {
    return models;
}

/**
* Executes a workflow with an input.
*
Expand Down
2 changes: 1 addition & 1 deletion serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ public void testWorkflows()
assertTrue(server.isRunning());
Channel channel = initTestChannel();

testPredictionsModels(channel);
testPredictions(channel, new String[] {"/predictions/m"});
testPredictionsWorkflows(channel);

channel.close().sync();
Expand Down
2 changes: 1 addition & 1 deletion serving/src/test/java/ai/djl/serving/WorkflowTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public void testFunctions() throws IOException, BadWorkflowException {
public void testLocalPerf() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/localPerf.json");
Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow();
ModelInfo<Input, Output> m = workflow.getModels().stream().findFirst().get();
ModelInfo<Input, Output> m = workflow.getModels().iterator().next();

Assert.assertEquals(m.getQueueSize(), 102);
Assert.assertEquals(m.getMaxIdleSeconds(), 62);
Expand Down
2 changes: 1 addition & 1 deletion serving/src/test/resources/workflow.config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
inference_address=https://127.0.0.1:8443
management_address=https://127.0.0.1:8443
model_store=build/models
load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz,https://resources.djl.ai/test-models/basic-serving-workflow.json
load_models=m=https://resources.djl.ai/test-models/mlp.tar.gz,https://resources.djl.ai/test-models/basic-serving-workflow.json
private_key_file=src/test/resources/key.pem
certificate_file=src/test/resources/certs.pem
max_request_size=10485760
7 changes: 5 additions & 2 deletions wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ public <I, O> WorkerPool<I, O> registerModel(ModelInfo<I, O> modelInfo) {
*/
public void unregisterModel(ModelInfo<?, ?> model) {
    WorkerPool<?, ?> pool = getWorkerPool(model);
    if (pool == null) {
        // The model was never registered or was already unregistered; previously this
        // dereferenced null. Treat it as a no-op rather than throwing an NPE.
        logger.warn("Model not registered: {}", model);
        return;
    }
    // Worker pools can be shared between workflows; only shut down the workers and
    // remove the pool once the last reference is released.
    if (pool.decreaseRef() <= 0) {
        logger.info("Unloading model: {}", model);
        pool.shutdownWorkers();
        workerPools.remove(model);
    }
}

/**
Expand Down
24 changes: 23 additions & 1 deletion wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
Expand All @@ -43,6 +44,7 @@ public class WorkerPool<I, O> {
private ExecutorService threadPool;
private Map<Device, WorkerGroup<I, O>> workerGroups;
private LinkedBlockingDeque<WorkerJob<I, O>> jobQueue;
private AtomicInteger refCnt;

/**
* Construct and initial data structure.
Expand All @@ -54,9 +56,29 @@ public class WorkerPool<I, O> {
this.model = model;
this.threadPool = threadPool;
workerGroups = new ConcurrentHashMap<>();
refCnt = new AtomicInteger(1);
}

ModelInfo<I, O> getModel() {
/** Bumps the reference count by one, recording an additional user of this pool. */
public void increaseRef() {
    refCnt.getAndIncrement();
}

/**
 * Decreases the reference count and returns the resulting count.
 *
 * <p>When the returned value reaches zero, the caller is expected to shut down the
 * pool's workers and discard the pool (see {@code WorkLoadManager#unregisterModel}).
 *
 * @return the reference count after decrementing
 */
public int decreaseRef() {
    return refCnt.decrementAndGet();
}

/**
 * Returns the model of the worker pool.
 *
 * <p>The returned {@code ModelInfo} instance may be shared by multiple workflows that
 * registered the same model.
 *
 * @return the model of the worker pool
 */
public ModelInfo<I, O> getModel() {
    return model;
}

Expand Down

0 comments on commit c77cb15

Please sign in to comment.