From 94656214c22fd8635f51149a80f5e9697da36032 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 28 Sep 2021 16:04:35 -0700 Subject: [PATCH] Separate WorkLoadManager This separates the WorkLoadManagement functionality from the main djl-serving into a separate module. The module first enables isolated usage without serving. It also helps convert the wlm to focus on CompletableFutures rather than netty. This makes it easier to leverage from pure Java. As a particular case, it also enables Translators to leverage the wlm. --- README.md | 4 + serving/build.gradle | 1 + .../main/java/ai/djl/serving/ModelServer.java | 2 +- .../serving/http/InferenceRequestHandler.java | 29 +-- .../http/ManagementRequestHandler.java | 4 +- .../djl/serving/{wlm => models}/Endpoint.java | 3 +- .../serving/{wlm => models}/ModelManager.java | 79 ++++++- .../serving/{wlm => models}/package-info.java | 2 +- .../ai/djl/serving/util/ConfigManager.java | 36 +--- .../src/main/java/ai/djl/serving/wlm/Job.java | 138 ------------ .../wlm/ScaleCapacityExceededException.java | 71 ------- serving/src/main/puml/architecture.puml | 12 +- settings.gradle | 1 + wlm/README.md | 32 +++ wlm/build.gradle | 101 +++++++++ wlm/gradlew | 201 ++++++++++++++++++ .../ai/djl/serving/wlm/BatchAggregator.java | 48 ++--- wlm/src/main/java/ai/djl/serving/wlm/Job.java | 88 ++++++++ .../java/ai/djl/serving/wlm/ModelInfo.java | 4 +- .../serving/wlm/PermanentBatchAggregator.java | 13 +- .../serving/wlm/TemporaryBatchAggregator.java | 13 +- .../ai/djl/serving/wlm/WorkLoadManager.java | 36 +++- .../ai/djl/serving/wlm/WorkerIdGenerator.java | 0 .../java/ai/djl/serving/wlm/WorkerState.java | 0 .../java/ai/djl/serving/wlm/WorkerThread.java | 58 +++-- .../java/ai/djl/serving/wlm/package-info.java | 18 ++ .../wlm/util/WlmCapacityException.java | 45 ++++ .../serving/wlm/util/WlmConfigManager.java | 82 +++++++ .../ai/djl/serving/wlm/util/WlmException.java | 45 ++++ .../wlm/util/WlmShutdownException.java | 45 ++++ .../ai/djl/serving/wlm/util/WorkerJob.java | 53 +++++ .../ai/djl/serving/wlm/util/package-info.java | 14 ++ wlm/src/main/javadoc/overview.html | 14 ++ .../ai/djl/serving/wlm/ModelInfoTest.java | 0 .../serving/wlm/WorkerIdGeneratorTest.java | 0 .../java/ai/djl/serving/wlm/package-info.java | 0 36 files changed, 944 insertions(+), 348 deletions(-) rename serving/src/main/java/ai/djl/serving/{wlm => models}/Endpoint.java (98%) rename serving/src/main/java/ai/djl/serving/{wlm => models}/ModelManager.java (72%) rename serving/src/main/java/ai/djl/serving/{wlm => models}/package-info.java (95%) delete mode 100644 serving/src/main/java/ai/djl/serving/wlm/Job.java delete mode 100644 serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java create mode 100644 wlm/README.md create mode 100644 wlm/build.gradle create mode 100755 wlm/gradlew rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/BatchAggregator.java (75%) create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/Job.java rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/ModelInfo.java (98%) rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java (83%) rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java (86%) rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java (87%) rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java (100%) rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/WorkerState.java (100%) rename {serving => wlm}/src/main/java/ai/djl/serving/wlm/WorkerThread.java (86%) create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/package-info.java create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java create mode 100644 wlm/src/main/javadoc/overview.html rename {serving => wlm}/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java (100%) rename {serving => wlm}/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java (100%) rename {serving => wlm}/src/test/java/ai/djl/serving/wlm/package-info.java (100%) diff --git a/README.md b/README.md index eb82ea7509..71e0d646dd 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,10 @@ You can install extra extensions to enable the following models: DJL serving is built on top of [Deep Java Library](https://djl.ai). You can visit [DJL github repository](https://github.com/deepjavalibrary/djl) to learn more about DJL. +It is also possible to leverage only the worker thread pool using the separate [WorkLoadManager](wlm) module. +The separate WorkLoadManager can be used by customers who want to take advantage of DJL serving's model batching +and threading but integrated into their own custom Java service. + ![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture.png) ## Key features diff --git a/serving/build.gradle b/serving/build.gradle index 876fda5d22..89a192414c 100644 --- a/serving/build.gradle +++ b/serving/build.gradle @@ -4,6 +4,7 @@ plugins { } dependencies { + api project(":wlm") api platform("ai.djl:bom:${project.version}") api "ai.djl:api" api "io.netty:netty-all:${netty_version}" diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java index f952ecb66f..37f808970c 100644 --- a/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -13,12 +13,12 @@ package ai.djl.serving; import ai.djl.repository.FilenameUtils; +import ai.djl.serving.models.ModelManager; import ai.djl.serving.plugins.FolderScanPluginManager; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.Connector; import ai.djl.serving.util.ServerGroups; import ai.djl.serving.wlm.ModelInfo; -import ai.djl.serving.wlm.ModelManager; import ai.djl.util.cuda.CudaUtils; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; diff --git a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java index bbce5fa97e..27ae660a5c 100644 --- a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java @@ -15,11 +15,11 @@ import ai.djl.ModelException; import ai.djl.modality.Input; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.serving.models.ModelManager; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.NettyUtils; import ai.djl.serving.wlm.Job; import ai.djl.serving.wlm.ModelInfo; -import ai.djl.serving.wlm.ModelManager; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; @@ -179,25 +179,7 @@ private void predict( ConfigManager.getInstance().getMaxBatchDelay(), ConfigManager.getInstance().getMaxIdleTime()) .thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, -1))) - .thenAccept( - m -> { - try { - if (!modelManager.addJob(new Job(ctx, m, input))) { - throw new ServiceUnavailableException( - "No worker is available to serve request: " - + modelName); - } - } catch (ModelNotFoundException e) { - logger.warn("Unexpected error", e); - NettyUtils.sendError(ctx, e); - } - }) - .exceptionally( - t -> { - logger.warn("Unexpected error", t); - NettyUtils.sendError(ctx, t); - return null; - }); + .thenApply(m -> modelManager.runJob(ctx, new Job(m, input))); return; } @@ -206,11 +188,6 @@ private void predict( return; } - Job job = new Job(ctx, model, input); - if (!modelManager.addJob(job)) { - logger.error("unable to process prediction. no free worker available."); - throw new ServiceUnavailableException( - "No worker is available to serve request: " + modelName); - } + modelManager.runJob(ctx, new Job(model, input)); } } diff --git a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index 7815b85e86..2e710bd2a2 100644 --- a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -14,10 +14,10 @@ import ai.djl.ModelException; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.serving.models.Endpoint; +import ai.djl.serving.models.ModelManager; import ai.djl.serving.util.NettyUtils; -import ai.djl.serving.wlm.Endpoint; import ai.djl.serving.wlm.ModelInfo; -import ai.djl.serving.wlm.ModelManager; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; diff --git a/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java b/serving/src/main/java/ai/djl/serving/models/Endpoint.java similarity index 98% rename from serving/src/main/java/ai/djl/serving/wlm/Endpoint.java rename to serving/src/main/java/ai/djl/serving/models/Endpoint.java index 7ad2fbde5b..e7e695073a 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java +++ b/serving/src/main/java/ai/djl/serving/models/Endpoint.java @@ -10,8 +10,9 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.serving.wlm; +package ai.djl.serving.models; +import ai.djl.serving.wlm.ModelInfo; import java.util.ArrayList; import java.util.List; import java.util.Map; diff --git a/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java similarity index 72% rename from serving/src/main/java/ai/djl/serving/wlm/ModelManager.java rename to serving/src/main/java/ai/djl/serving/models/ModelManager.java index 539d964d77..2edf7ce99f 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java +++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.serving.wlm; +package ai.djl.serving.models; import ai.djl.Device; import ai.djl.ModelException; @@ -22,6 +22,19 @@ import ai.djl.serving.http.BadRequestException; import ai.djl.serving.http.DescribeModelResponse; import ai.djl.serving.util.ConfigManager; +import ai.djl.serving.util.NettyUtils; +import ai.djl.serving.wlm.Job; +import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkLoadManager; +import ai.djl.serving.wlm.WorkerThread; +import ai.djl.serving.wlm.util.WlmCapacityException; +import ai.djl.serving.wlm.util.WlmShutdownException; +import ai.djl.translate.TranslateException; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; import java.io.IOException; import java.util.HashSet; import java.util.List; @@ -236,14 +249,68 @@ public Set getStartupModels() { } /** - * Adds an inference job to the job queue. Assign the job to the next free worker. + * Runs an inference job by assigning the job to the next free worker. * + * @param ctx the netty channel handler context where the job response will be sent * @param job an inference job to be executed - * @return {@code true} if submit success - * @throws ModelNotFoundException if the model is not registered + * @return a future for after the netty response was sent */ - public boolean addJob(Job job) throws ModelNotFoundException { - return wlm.addJob(job); + @SuppressWarnings("PMD.InvalidLogMessageFormat") + public CompletableFuture runJob(ChannelHandlerContext ctx, Job job) { + return wlm.runJob(job) + .whenComplete( + (output, throwable) -> { + if (throwable != null) { + HttpResponseStatus status; + if (throwable instanceof TranslateException) { + status = HttpResponseStatus.BAD_REQUEST; + } else if (throwable instanceof WlmShutdownException) { + status = HttpResponseStatus.SERVICE_UNAVAILABLE; + logger.error("Unable to process prediction. Worker shutdown"); + } else if (throwable instanceof WlmCapacityException) { + logger.error( + "Unable to process prediction. Worker capacity exceeded"); + status = HttpResponseStatus.SERVICE_UNAVAILABLE; + } else { + logger.warn("Unexpected error", throwable); + status = HttpResponseStatus.INTERNAL_SERVER_ERROR; + } + + /* + * We can load the models based on the configuration file.Since this Job is + * not driven by the external connections, we could have a empty context for + * this job. We shouldn't try to send a response to ctx if this is not triggered + * by external clients. + */ + if (ctx != null) { + NettyUtils.sendError(ctx, status, throwable); + } + } else { // Handle output + FullHttpResponse resp = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); + for (Map.Entry entry : + output.getProperties().entrySet()) { + resp.headers().set(entry.getKey(), entry.getValue()); + } + resp.content().writeBytes(output.getContent()); + + /* + * We can load the models based on the configuration file.Since this Job is + * not driven by the external connections, we could have a empty context for + * this job. We shouldn't try to send a response to ctx if this is not triggered + * by external clients. + */ + if (ctx != null) { + NettyUtils.sendHttpResponse(ctx, resp, true); + } + } + + logger.debug( + "Waiting time: {}, Backend time: {}", + job.getScheduled() - job.getBegin(), + System.currentTimeMillis() - job.getScheduled()); + }); } /** diff --git a/serving/src/main/java/ai/djl/serving/wlm/package-info.java b/serving/src/main/java/ai/djl/serving/models/package-info.java similarity index 95% rename from serving/src/main/java/ai/djl/serving/wlm/package-info.java rename to serving/src/main/java/ai/djl/serving/models/package-info.java index a1eb814e56..e8657db53b 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/package-info.java +++ b/serving/src/main/java/ai/djl/serving/models/package-info.java @@ -11,4 +11,4 @@ * and limitations under the License. */ /** Contains classes that manage model lifecycle. */ -package ai.djl.serving.wlm; +package ai.djl.serving.models; diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index 9c66c0c0e9..e4bb50878d 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -12,14 +12,12 @@ */ package ai.djl.serving.util; -import ai.djl.Device; -import ai.djl.ndarray.NDManager; import ai.djl.serving.Arguments; +import ai.djl.serving.wlm.util.WlmConfigManager; import ai.djl.util.Utils; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; - import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; @@ -127,6 +125,9 @@ public static void init(Arguments args) { "log4j2.contextSelector", "org.apache.logging.log4j.core.async.AsyncLoggerContextSelector"); } + + // Sets corresponding config in the WlmConfigManager + WlmConfigManager.getInstance().setDebug(instance.isDebug()); } /** @@ -209,35 +210,6 @@ public int getMaxBatchDelay() { return getIntProperty(MAX_BATCH_DELAY, 300); } - /** - * Returns the default number of workers for a new registered model. - * - * @param manager the {@code NDManager} the model uses - * @param target the target number of worker - * @return the default number of workers for a new registered model - */ - public int getDefaultWorkers(NDManager manager, int target) { - if (target == 0) { - return 0; - } else if (target == -1 && isDebug()) { - return 1; - } - if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) { - if ("MXNet".equals(manager.getEngine().getEngineName())) { - // FIXME: MXNet GPU Model doesn't support multi-threading - return 1; - } else if (target == -1) { - target = 2; // default to max 2 workers per GPU - } - return target; - } - - if (target > 0) { - return target; - } - return Runtime.getRuntime().availableProcessors(); - } - /** * Returns the model server home directory. * diff --git a/serving/src/main/java/ai/djl/serving/wlm/Job.java b/serving/src/main/java/ai/djl/serving/wlm/Job.java deleted file mode 100644 index 26f7a09f2f..0000000000 --- a/serving/src/main/java/ai/djl/serving/wlm/Job.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.serving.wlm; - -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.serving.util.NettyUtils; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A class represents an inference job. */ -public class Job { - - private static final Logger logger = LoggerFactory.getLogger(Job.class); - - private ChannelHandlerContext ctx; - - private ModelInfo modelInfo; - private Input input; - private long begin; - private long scheduled; - - /** - * Constructs an new {@code Job} instance. - * - * @param ctx the {@code ChannelHandlerContext} - * @param modelInfo the model to run the job - * @param input the input data - */ - public Job(ChannelHandlerContext ctx, ModelInfo modelInfo, Input input) { - this.ctx = ctx; - this.modelInfo = modelInfo; - this.input = input; - - begin = System.currentTimeMillis(); - scheduled = begin; - } - - /** - * Returns the request id. - * - * @return the request id - */ - public String getRequestId() { - return input.getRequestId(); - } - - /** - * Returns the model that associated with this job. - * - * @return the model that associated with this job - */ - public ModelInfo getModel() { - return modelInfo; - } - - /** - * Returns the input data. - * - * @return the input data - */ - public Input getInput() { - return input; - } - - /** Marks the job has been scheduled. */ - public void setScheduled() { - scheduled = System.currentTimeMillis(); - } - - /** - * Sends the response back to the client. - * - * @param output the output - */ - public void sendOutput(Output output) { - FullHttpResponse resp = - new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); - for (Map.Entry entry : output.getProperties().entrySet()) { - resp.headers().set(entry.getKey(), entry.getValue()); - } - resp.content().writeBytes(output.getContent()); - - /* - * We can load the models based on the configuration file.Since this Job is - * not driven by the external connections, we could have a empty context for - * this job. We shouldn't try to send a response to ctx if this is not triggered - * by external clients. - */ - if (ctx != null) { - NettyUtils.sendHttpResponse(ctx, resp, true); - } - - logger.debug( - "Waiting time: {}, Backend time: {}", - scheduled - begin, - System.currentTimeMillis() - scheduled); - } - - /** - * Sends error to the client. - * - * @param status the HTTP status - * @param error the exception - */ - public void sendError(HttpResponseStatus status, Throwable error) { - /* - * We can load the models based on the configuration file.Since this Job is - * not driven by the external connections, we could have a empty context for - * this job. We shouldn't try to send a response to ctx if this is not triggered - * by external clients. - */ - if (ctx != null) { - NettyUtils.sendError(ctx, status, error); - } - - logger.debug( - "Waiting time: {}, Inference time: {}", - scheduled - begin, - System.currentTimeMillis() - begin); - } -} diff --git a/serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java b/serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java deleted file mode 100644 index 982a81087d..0000000000 --- a/serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.serving.wlm; - -/** - * Is thrown when capacity of workers is reached during autoscaling. - * - * @author erik.bamberg@web.de - */ -public class ScaleCapacityExceededException extends Exception { - - /** serialVersionUDI for this class cause exceptions are serializable. */ - private static final long serialVersionUID = 1633130362838844091L; - - /** No arguments. */ - public ScaleCapacityExceededException() {} - - /** - * construct using a message. - * - * @param message the message. - */ - public ScaleCapacityExceededException(String message) { - super(message); - } - - /** - * construct using a cause. - * - * @param cause the root cause. - */ - public ScaleCapacityExceededException(Throwable cause) { - super(cause); - } - - /** - * construct using a message and a clause. - * - * @param message the message. - * @param cause the root cause. - */ - public ScaleCapacityExceededException(String message, Throwable cause) { - super(message, cause); - } - - /** - * construct using a message cause and flags. - * - * @param message the message. - * @param cause the root cause. - * @param enableSuppression enable suppression or not. - * @param writableStackTrace flag if writableStackTrace. - */ - public ScaleCapacityExceededException( - String message, - Throwable cause, - boolean enableSuppression, - boolean writableStackTrace) { - super(message, cause, enableSuppression, writableStackTrace); - } -} diff --git a/serving/src/main/puml/architecture.puml b/serving/src/main/puml/architecture.puml index 30d8ef5150..60d3a5c400 100644 --- a/serving/src/main/puml/architecture.puml +++ b/serving/src/main/puml/architecture.puml @@ -26,6 +26,7 @@ package "DJL Serving - single process" { HTTP - REST_API } + package "WorkLoad Manager" as wlm { frame "Worker thread pool" as wp { package Workers [ resent18_v1 (GPU0) @@ -39,8 +40,9 @@ package "DJL Serving - single process" { } queue "Job queue\nauto batch" as jq - [Model Manager] as wlm + } + [Model Manager] as mm frame Models { package Engines [ PyTorch @@ -54,10 +56,10 @@ package "DJL Serving - single process" { } REST_API -> jq - REST_API ---> wlm + REST_API ---> mm jq -> Workers : auto scale - jq ...> wlm - wlm -right-> Engines + jq ...> mm + mm -right-> Engines Engines -[hidden]up- [Translator] Translator <-up- Workers } @@ -66,7 +68,7 @@ frame "Python Workers" { [preprocess] -[hidden]-- [postprocess] } -wlm -down-> URL : load model +mm -down-> URL : load model Translator -up.> preprocess : optional Translator -down.> postprocess : optional @enduml diff --git a/settings.gradle b/settings.gradle index d158bbd3c6..94b449ea07 100644 --- a/settings.gradle +++ b/settings.gradle @@ -3,3 +3,4 @@ include ':central' include ':plugins:plugin-management-plugin' include ':plugins:static-file-plugin' include ':serving' +include ':wlm' diff --git a/wlm/README.md b/wlm/README.md new file mode 100644 index 0000000000..fac74ea145 --- /dev/null +++ b/wlm/README.md @@ -0,0 +1,32 @@ +# DJL Serving - WorkLoadManager + +The djl-serving serving can be divided into a frontend and backend. +The frontend is a [netty](https://netty.io/) webserver that manages incoming requests and operators the control plane. +The backend WorkLoadManager handles the model batching, workers, and threading for high-performance inference. + +For those who already have a web server infrastructure but want to operate high-performance inference, it is possible to use only the WorkLoadManager. +For this reason, we have it split apart into a separate module. + +Using the WorkLoadManager is quite simple. First, create a new one through the constructor: + +```java +WorkLoadManager wlm = new WorkLoadManager(); +``` + +You can also configure the WorkLoadManager by using the static `WlmConfigManager`. + +Then, you can construct an instance of the `ModelInfo` for each model you will want to run through `wlm`. +With the `ModelInfo`, you are able to build a `Job` once you receive input: + +```java +ModelInfo modelInfo = new ModelInfo(...); +Job job = new Job(modelInfo, input); +``` + +Once you have your job, it can be submitted to the WorkLoadManager. +It will automatically spin up workers if none are created and manage worker numbers. +Then, it returns a `CompletableFuture` for the result. + +```java +CompletableFuture futureResult = wlm.runJob(job); +``` diff --git a/wlm/build.gradle b/wlm/build.gradle new file mode 100644 index 0000000000..ae01bc1098 --- /dev/null +++ b/wlm/build.gradle @@ -0,0 +1,101 @@ +plugins { + id "maven-publish" + id "signing" +} + +dependencies { + api platform("ai.djl:bom:${project.version}") + api "ai.djl:api" + api "org.slf4j:slf4j-api:${slf4j_version}" + + testImplementation("org.testng:testng:${testng_version}") { + exclude group: "junit", module: "junit" + } +} + +java { + withJavadocJar() + withSourcesJar() +} + +javadoc { + title "DJL Serving WorkLoadManager ${version}" + options.encoding = "UTF-8" + options.overview "src/main/javadoc/overview.html" + options.addBooleanOption("-allow-script-in-comments", true) +} + +task uploadJavadoc(type: Exec) { + dependsOn javadoc + commandLine "sh", "-c", "find . -name .DS_Store | xargs rm && aws s3 sync build/docs/javadoc s3://javadoc-djl-ai/${project.name}/${version} > build/upload.log" +} + +signing { + required(project.hasProperty("staging") || project.hasProperty("snapshot")) + def signingKey = findProperty("signingKey") + def signingPassword = findProperty("signingPassword") + useInMemoryPgpKeys(signingKey, signingPassword) + sign publishing.publications +} + +publishing { + publications { + maven(MavenPublication) { + from components.java + artifacts = [jar, javadocJar, sourcesJar] + pom { + name = "DJL Serving WorkLoadManager" + description = "DJL Serving WorkLoadManager" + url = "http://www.djl.ai/" + + packaging = "jar" + + licenses { + license { + name = 'The Apache License, Version 2.0' + url = 'https://www.apache.org/licenses/LICENSE-2.0' + } + } + + scm { + connection = "scm:git:git@github.com:deepjavalibrary/djl-serving.git" + developerConnection = "scm:git:git@github.com:deepjavalibrary/djl-serving.git" + url = "https://github.com/deepjavalibrary/djl-serving" + tag = "HEAD" + } + + developers { + developer { + name = "DJL.AI Team" + email = "djl-dev@amazon.com" + organization = "Amazon AI" + organizationUrl = "https://amazon.com" + } + } + } + } + } + + repositories { + maven { + if (project.hasProperty("snapshot")) { + name = "snapshot" + url = "https://oss.sonatype.org/content/repositories/snapshots/" + credentials { + username = findProperty("ossrhUsername") + password = findProperty("ossrhPassword") + } + } else if (project.hasProperty("staging")) { + name = "staging" + url = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + credentials { + username = findProperty("ossrhUsername") + password = findProperty("ossrhPassword") + } + } else { + name = "local" + url = "build/repo" + } + } + } +} diff --git a/wlm/gradlew b/wlm/gradlew new file mode 100755 index 0000000000..b70597b686 --- /dev/null +++ b/wlm/gradlew @@ -0,0 +1,201 @@ +#!/usr/bin/env sh + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +########################################################################################## +# Extension to allow automatically downloading the gradle-wrapper.jar +# This allows using the maven wrapper in projects that prohibit checking in binary data. +########################################################################################## +WRAPPER_JAR_PATH="$APP_HOME/gradle/wrapper/gradle-wrapper.jar" +if [ ! -r "${WRAPPER_JAR_PATH}" ]; then + jarUrl="https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar" + if command -v wget > /dev/null; then + wget -q "${jarUrl}" -O "${WRAPPER_JAR_PATH}" + elif command -v curl > /dev/null; then + curl -s -o "${WRAPPER_JAR_PATH}" "$jarUrl" + else + javaClass="$APP_HOME/gradle/wrapper/GradleWrapperDownloader.java" + if [ -e "$javaClass" ]; then + if [ ! -e "$APP_HOME/gradle/wrapper/GradleWrapperDownloader.class" ]; then + # Compiling the Java class + ("${JAVACMD}c" "$javaClass") + fi + if [ -e "$APP_HOME/gradle/wrapper/GradleWrapperDownloader.class" ]; then + ("$JAVACMD" -cp gradle/wrapper GradleWrapperDownloader "$APP_HOME") + fi + fi + fi +fi +########################################################################################## +# End of extension +########################################################################################## + + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=$(save "$@") + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong +if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi + +exec "$JAVACMD" "$@" diff --git a/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java similarity index 75% rename from serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java rename to wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java index 25ed8515d3..2a6bd046d6 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java @@ -14,7 +14,7 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; -import io.netty.handler.codec.http.HttpResponseStatus; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.ArrayList; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; @@ -30,8 +30,8 @@ abstract class BatchAggregator { protected int batchSize; protected int maxBatchDelay; - protected List jobs; - protected LinkedBlockingDeque jobQueue; + protected List wjs; + protected LinkedBlockingDeque jobQueue; /** * Constructs a new {@code BbatchAggregator} instance. @@ -39,11 +39,11 @@ abstract class BatchAggregator { * @param model the model to use. * @param jobQueue the job queue for polling data from. */ - public BatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { + public BatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { this.batchSize = model.getBatchSize(); this.maxBatchDelay = model.getMaxBatchDelay(); this.jobQueue = jobQueue; - jobs = new ArrayList<>(); + wjs = new ArrayList<>(); } /** @@ -54,9 +54,10 @@ public BatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { * queue. */ public List getRequest() throws InterruptedException { - jobs = pollBatch(); - List list = new ArrayList<>(jobs.size()); - for (Job job : jobs) { + wjs = pollBatch(); + List list = new ArrayList<>(wjs.size()); + for (WorkerJob wj : wjs) { + Job job = wj.getJob(); job.setScheduled(); list.add(job.getInput()); } @@ -69,30 +70,29 @@ public List getRequest() throws InterruptedException { * @param outputs list of model-outputs in same order as the input objects. */ public void sendResponse(List outputs) { - if (jobs.size() != outputs.size()) { + if (wjs.size() != outputs.size()) { throw new IllegalStateException("Not all jobs get response."); } int i = 0; for (Output output : outputs) { - Job job = jobs.get(i++); - output.setRequestId(job.getRequestId()); - job.sendOutput(output); + WorkerJob wj = wjs.get(i++); + output.setRequestId(wj.getJob().getRequestId()); + wj.getFuture().complete(output); } - jobs.clear(); + wjs.clear(); } /** - * Sends an error response to client. + * Completes the job with an error. * - * @param status the HTTP status * @param error the exception */ - public void sendError(HttpResponseStatus status, Throwable error) { - for (Job job : jobs) { - job.sendError(status, error); + public void sendError(Throwable error) { + for (WorkerJob wj : wjs) { + wj.getFuture().completeExceptionally(error); } - jobs.clear(); + wjs.clear(); } /** @@ -101,7 +101,7 @@ public void sendError(HttpResponseStatus status, Throwable error) { * @return a list of jobs read by this batch interation. * @throws InterruptedException if interrupted */ - protected abstract List pollBatch() throws InterruptedException; + protected abstract List pollBatch() throws InterruptedException; /** * Checks if this {@code BatchAggregator} and the thread can be shutdown or if this aggregator @@ -112,19 +112,19 @@ public void sendError(HttpResponseStatus status, Throwable error) { */ public abstract boolean isFinished(); - protected void drainTo(List list, int maxDelay) throws InterruptedException { + protected void drainTo(List list, int maxDelay) throws InterruptedException { long begin = System.currentTimeMillis(); jobQueue.drainTo(list, batchSize - 1); int remain = batchSize - list.size(); for (int i = 0; i < remain; ++i) { - Job job = jobQueue.poll(maxDelay, TimeUnit.MILLISECONDS); - if (job == null) { + WorkerJob wj = jobQueue.poll(maxDelay, TimeUnit.MILLISECONDS); + if (wj == null || wj.getJob() == null) { break; } long end = System.currentTimeMillis(); maxDelay -= end - begin; begin = end; - list.add(job); + list.add(wj); if (maxDelay <= 0) { break; } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/Job.java b/wlm/src/main/java/ai/djl/serving/wlm/Job.java new file mode 100644 index 0000000000..04e5263ba1 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/Job.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm; + +import ai.djl.modality.Input; + +/** A class represents an inference job. */ +public class Job { + + private ModelInfo modelInfo; + private Input input; + private long begin; + private long scheduled; + + /** + * Constructs a new {@code Job} instance. + * + * @param modelInfo the model to run the job + * @param input the input data + */ + public Job(ModelInfo modelInfo, Input input) { + this.modelInfo = modelInfo; + this.input = input; + + begin = System.currentTimeMillis(); + scheduled = begin; + } + + /** + * Returns the request id. + * + * @return the request id + */ + public String getRequestId() { + return input.getRequestId(); + } + + /** + * Returns the model that associated with this job. + * + * @return the model that associated with this job + */ + public ModelInfo getModel() { + return modelInfo; + } + + /** + * Returns the input data. + * + * @return the input data + */ + public Input getInput() { + return input; + } + + /** + * Returns the job begin time. + * + * @return the job begin time + */ + public long getBegin() { + return begin; + } + + /** + * Returns the job scheduled time. + * + * @return the job scheduled time + */ + public long getScheduled() { + return scheduled; + } + + /** Marks the job has been scheduled. */ + public void setScheduled() { + scheduled = System.currentTimeMillis(); + } +} diff --git a/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java similarity index 98% rename from serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java rename to wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index b3ec29bb16..317bf428c6 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -17,7 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.repository.FilenameUtils; import ai.djl.repository.zoo.ZooModel; -import ai.djl.serving.util.ConfigManager; +import ai.djl.serving.wlm.util.WlmConfigManager; import java.net.URI; import java.nio.file.Path; import java.util.Objects; @@ -98,7 +98,7 @@ public ModelInfo configureModelBatch(int batchSize, int maxBatchDelay) { */ public ModelInfo scaleWorkers(int minWorkers, int maxWorkers) { NDManager manager = model.getNDManager(); - ConfigManager configManager = ConfigManager.getInstance(); + WlmConfigManager configManager = WlmConfigManager.getInstance(); this.maxWorkers = configManager.getDefaultWorkers(manager, maxWorkers); this.minWorkers = Math.min(minWorkers, this.maxWorkers); return this; diff --git a/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java similarity index 83% rename from serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java rename to wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java index b7d5637fed..dec1c36b91 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java @@ -12,6 +12,7 @@ */ package ai.djl.serving.wlm; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.ArrayList; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; @@ -34,17 +35,17 @@ public class PermanentBatchAggregator extends BatchAggregator { * @param model the model to use. * @param jobQueue the job queue for polling data from. */ - public PermanentBatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { + public PermanentBatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { super(model, jobQueue); } /** {@inheritDoc} */ @Override - protected List pollBatch() throws InterruptedException { - List list = new ArrayList<>(batchSize); - Job job = jobQueue.take(); - list.add(job); - logger.trace("get first job: {}", job.getRequestId()); + protected List pollBatch() throws InterruptedException { + List list = new ArrayList<>(batchSize); + WorkerJob wj = jobQueue.take(); + list.add(wj); + logger.trace("get first job: {}", wj.getJob().getRequestId()); drainTo(list, maxBatchDelay); logger.trace("sending jobs, size: {}", list.size()); return list; diff --git a/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java similarity index 86% rename from serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java rename to wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java index e3dc134859..1482a5f5ad 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java @@ -12,6 +12,7 @@ */ package ai.djl.serving.wlm; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.ArrayList; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; @@ -37,7 +38,7 @@ public class TemporaryBatchAggregator extends BatchAggregator { * @param model the model to run for. * @param jobQueue reference to external job queue for polling. */ - public TemporaryBatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { + public TemporaryBatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { super(model, jobQueue); this.idleSince = System.currentTimeMillis(); this.maxIdleTime = model.getMaxIdleTime(); @@ -45,11 +46,11 @@ public TemporaryBatchAggregator(ModelInfo model, LinkedBlockingDeque jobQue /** {@inheritDoc} */ @Override - protected List pollBatch() throws InterruptedException { - List list = new ArrayList<>(batchSize); - Job job = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS); - if (job != null) { - list.add(job); + protected List pollBatch() throws InterruptedException { + List list = new ArrayList<>(batchSize); + WorkerJob wj = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS); + if (wj != null && wj.getJob() != null) { + list.add(wj); drainTo(list, maxBatchDelay); logger.trace("sending jobs, size: {}", list.size()); idleSince = System.currentTimeMillis(); diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java similarity index 87% rename from serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java rename to wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java index 4e702821f7..155db70f99 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java @@ -12,8 +12,13 @@ */ package ai.djl.serving.wlm; +import ai.djl.modality.Output; +import ai.djl.serving.wlm.util.WlmCapacityException; +import ai.djl.serving.wlm.util.WlmShutdownException; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.Collections; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; @@ -24,19 +29,19 @@ import org.slf4j.LoggerFactory; /** - * WorkLoadManager is repsonsible to manage the work load of worker thread. the manage scales + * WorkLoadManager is responsible to manage the work load of worker thread. the manage scales * up/down the required amount of worker threads per model. * * @author erik.bamberg@web.de */ -class WorkLoadManager { +public class WorkLoadManager { private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); private ExecutorService threadPool; private ConcurrentHashMap workerPools; - /** Constructs a {@code WorkLoadManager} instance. */ + /** Constructs a {@link WorkLoadManager} instance. */ public WorkLoadManager() { threadPool = Executors.newCachedThreadPool(); workerPools = new ConcurrentHashMap<>(); @@ -69,18 +74,27 @@ public List getWorkers(ModelInfo modelInfo) { * @param job an inference job to be executed. * @return {@code true} if submit success, false otherwise. */ - public boolean addJob(Job job) { + public CompletableFuture runJob(Job job) { + CompletableFuture result = new CompletableFuture<>(); ModelInfo modelInfo = job.getModel(); int maxWorkers = modelInfo.getMaxWorkers(); if (maxWorkers == 0) { logger.info("All model workers has been shutdown: {}", modelInfo.getModelName()); - return false; + result.completeExceptionally( + new WlmShutdownException( + "No worker is available to serve request: " + + modelInfo.getModelName())); + return result; } WorkerPool pool = getWorkerPoolForModel(modelInfo); - LinkedBlockingDeque queue = pool.getJobQueue(); - if (!queue.offer(job)) { + LinkedBlockingDeque queue = pool.getJobQueue(); + if (!queue.offer(new WorkerJob(job, result))) { logger.warn("Worker queue capacity exceeded for model: {}", modelInfo.getModelName()); - return false; + result.completeExceptionally( + new WlmCapacityException( + "No worker is available to serve request: " + + modelInfo.getModelName())); + return result; } int currentWorkers = getNumRunningWorkers(modelInfo); @@ -97,7 +111,7 @@ public boolean addJob(Job job) { } } } - return true; + return result; } /** @@ -208,7 +222,7 @@ private void addThreads( private static final class WorkerPool { private List workers; - private LinkedBlockingDeque jobQueue; + private LinkedBlockingDeque jobQueue; private String modelName; /** @@ -236,7 +250,7 @@ public List getWorkers() { * * @return the jobQueue */ - public LinkedBlockingDeque getJobQueue() { + public LinkedBlockingDeque getJobQueue() { return jobQueue; } diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java similarity index 100% rename from serving/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java rename to wlm/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkerState.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerState.java similarity index 100% rename from serving/src/main/java/ai/djl/serving/wlm/WorkerState.java rename to wlm/src/main/java/ai/djl/serving/wlm/WorkerState.java diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkerThread.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java similarity index 86% rename from serving/src/main/java/ai/djl/serving/wlm/WorkerThread.java rename to wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java index fedc88df3a..7411581b60 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/WorkerThread.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java @@ -12,14 +12,12 @@ */ package ai.djl.serving.wlm; -import ai.djl.engine.EngineException; import ai.djl.inference.Predictor; import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.repository.zoo.ZooModel; -import ai.djl.serving.http.InternalServerException; -import ai.djl.translate.TranslateException; -import io.netty.handler.codec.http.HttpResponseStatus; +import ai.djl.serving.wlm.util.WlmException; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.atomic.AtomicBoolean; @@ -27,7 +25,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -final class WorkerThread implements Runnable { +/** The {@link WorkerThread} is the worker managed by the {@link WorkLoadManager}. */ +public final class WorkerThread implements Runnable { private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class); @@ -76,12 +75,9 @@ public void run() { try { List reply = predictor.batchPredict(req); aggregator.sendResponse(reply); - } catch (EngineException e) { + } catch (Exception e) { logger.warn("Failed to predict", e); - aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e); - } catch (TranslateException e) { - logger.warn("Failed to predict", e); - aggregator.sendError(HttpResponseStatus.BAD_REQUEST, e); + aggregator.sendError(e); } } req = null; @@ -96,40 +92,70 @@ public void run() { currentThread.set(null); shutdown(WorkerState.WORKER_STOPPED); if (req != null) { - Exception e = new InternalServerException(errorMessage); - aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + Exception e = new WlmException(errorMessage); + aggregator.sendError(e); } } } + /** + * Returns the worker thread ID. + * + * @return the worker thread ID + */ public int getWorkerId() { return workerId; } + /** + * Returns true if the worker thread is running. + * + * @return true if the worker thread is running + */ public boolean isRunning() { return running.get(); } + /** + * Returns the gpu id used by the thread. + * + * @return the gpu id used by the thread + */ public int getGpuId() { return gpuId; } + /** + * Returns the thread start time. + * + * @return the thread start time + */ public long getStartTime() { return startTime; } + /** + * Returns the worker state. + * + * @return the worker state + */ public WorkerState getState() { return state; } + /** + * Shuts down the worker thread. + * + * @param state the state to set the thread to + */ public void shutdown(WorkerState state) { running.set(false); setState(state); Thread thread = currentThread.getAndSet(null); if (thread != null) { thread.interrupt(); - Exception e = new InternalServerException("Worker shutting down"); - aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + Exception e = new WlmException("Worker shutting down"); + aggregator.sendError(e); } predictor.close(); } @@ -175,7 +201,7 @@ public static class Builder { private ModelInfo model; private BatchAggregator aggregator; - private LinkedBlockingDeque jobQueue; + private LinkedBlockingDeque jobQueue; private boolean fixPoolThread; Builder() { @@ -253,7 +279,7 @@ public Builder optAggregator(BatchAggregator aggregator) { * @param jobQueue the jobQueue to set * @return self-reference to this builder. */ - public Builder setJobQueue(LinkedBlockingDeque jobQueue) { + public Builder setJobQueue(LinkedBlockingDeque jobQueue) { this.jobQueue = jobQueue; return self(); } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/package-info.java b/wlm/src/main/java/ai/djl/serving/wlm/package-info.java new file mode 100644 index 0000000000..681d0e86e1 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** + * Contains the model server backend which manages worker threads and executes jobs on models. + * + * @see ai.djl.serving.wlm.WorkLoadManager + */ +package ai.djl.serving.wlm; diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java new file mode 100644 index 0000000000..22042c7746 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +/** Thrown to throttle when a job is run but the job queue capacity is exceeded. */ +public class WlmCapacityException extends RuntimeException { + + static final long serialVersionUID = 1L; + + /** + * Constructs a {@link WlmCapacityException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public WlmCapacityException(String message) { + super(message); + } + + /** + * Constructs a {@link WlmCapacityException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public WlmCapacityException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java new file mode 100644 index 0000000000..181991becd --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java @@ -0,0 +1,82 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +import ai.djl.Device; +import ai.djl.ndarray.NDManager; + +/** This manages some configurations used by the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +public final class WlmConfigManager { + + private static final WlmConfigManager INSTANCE = new WlmConfigManager(); + + private boolean debug; + + /** + * Returns the singleton {@code ConfigManager} instance. + * + * @return the singleton {@code ConfigManager} instance + */ + public static WlmConfigManager getInstance() { + return INSTANCE; + } + + /** + * Returns if debug is enabled. + * + * @return {@code true} if debug is enabled + */ + public boolean isDebug() { + return debug; + } + + /** + * Sets debug mode. + * + * @param debug true to enable debug mode and false to disable + * @return this config manager + */ + public WlmConfigManager setDebug(boolean debug) { + this.debug = debug; + return this; + } + + /** + * Returns the default number of workers for a new registered model. + * + * @param manager the {@code NDManager} the model uses + * @param target the target number of worker + * @return the default number of workers for a new registered model + */ + public int getDefaultWorkers(NDManager manager, int target) { + if (target == 0) { + return 0; + } else if (target == -1 && isDebug()) { + return 1; + } + if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) { + if ("MXNet".equals(manager.getEngine().getEngineName())) { + // FIXME: MXNet GPU Model doesn't support multi-threading + return 1; + } else if (target == -1) { + target = 2; // default to max 2 workers per GPU + } + return target; + } + + if (target > 0) { + return target; + } + return Runtime.getRuntime().availableProcessors(); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java new file mode 100644 index 0000000000..7293205456 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +/** Thrown when an exception occurs inside the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +public class WlmException extends RuntimeException { + + static final long serialVersionUID = 1L; + + /** + * Constructs a {@link WlmException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public WlmException(String message) { + super(message); + } + + /** + * Constructs a {@link WlmException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public WlmException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java new file mode 100644 index 0000000000..3527fa1a7a --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +/** Thrown when a job is run but all workers are shutdown. */ +public class WlmShutdownException extends RuntimeException { + + static final long serialVersionUID = 1L; + + /** + * Constructs a {@link WlmShutdownException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public WlmShutdownException(String message) { + super(message); + } + + /** + * Constructs a {@link WlmShutdownException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public WlmShutdownException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java new file mode 100644 index 0000000000..dce9c01a30 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java @@ -0,0 +1,53 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +import ai.djl.modality.Output; +import ai.djl.serving.wlm.Job; +import java.util.concurrent.CompletableFuture; + +/** A {@link Job} containing metadata from the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +public final class WorkerJob { + + private final Job job; + private final CompletableFuture future; + + /** + * Constructs a new {@link WorkerJob}. + * + * @param job the job to execute + * @param future the future containing the job response + */ + public WorkerJob(Job job, CompletableFuture future) { + this.job = job; + this.future = future; + } + + /** + * Returns the {@link Job}. + * + * @return the {@link Job} + */ + public Job getJob() { + return job; + } + + /** + * Returns the future for the job. + * + * @return the future for the job + */ + public CompletableFuture getFuture() { + return future; + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java b/wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java new file mode 100644 index 0000000000..a1d39eb57d --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains utilities to support the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +package ai.djl.serving.wlm.util; diff --git a/wlm/src/main/javadoc/overview.html b/wlm/src/main/javadoc/overview.html new file mode 100644 index 0000000000..a7ef6f1731 --- /dev/null +++ b/wlm/src/main/javadoc/overview.html @@ -0,0 +1,14 @@ + + + + + +

This document is the API specification for the DJL Serving WorkLoadManager.

+ +

+ This module provides the worker and thread management for a high-performance inference server. + See here for more details. +

+ + + diff --git a/serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java similarity index 100% rename from serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java rename to wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java diff --git a/serving/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java b/wlm/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java similarity index 100% rename from serving/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java rename to wlm/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java diff --git a/serving/src/test/java/ai/djl/serving/wlm/package-info.java b/wlm/src/test/java/ai/djl/serving/wlm/package-info.java similarity index 100% rename from serving/src/test/java/ai/djl/serving/wlm/package-info.java rename to wlm/src/test/java/ai/djl/serving/wlm/package-info.java